Algebraic solver problems (Laplace approximation in Stan)

OK, we fixed it. There were two problems:

  • We had the wrong target (which is an issue!!)
  • The algebraic solver is quite flaky when it comes to actually solving things.

To do these in order:
1: We had the wrong target

Yes, but it is wrong. That term only contributes a log-determinant. I've written the density out explicitly in the code below so it looks less weird.
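For reference, here is the target written out, assuming the model in the code: y_i ~ Poisson(exp(x_{index[i]})), x_j ~ N(0, sigma^2) independently, with n_j = number_of_samples[j] and s_j = sums[j]. Because the negative Hessian is diagonal for this model, the log-determinant is just a sum of logs, and the (M/2) log(2π) from the prior cancels the one from the Gaussian approximation evaluated at its mean, so no constant is dropped.

```latex
\begin{aligned}
\log p(y \mid \sigma)
  &\approx \log p(y \mid x^*) + \log p(x^* \mid \sigma) - \log p_G(x^* \mid \sigma, y) \\
  &= \sum_{i=1}^{N} \log \mathrm{Poisson}\!\left(y_i \mid e^{x^*_{\mathrm{index}[i]}}\right)
     - \frac{(x^*)^\top x^*}{2\sigma^2} - M \log \sigma
     - \frac{1}{2} \sum_{j=1}^{M} \log\!\left(n_j e^{x^*_j} + \sigma^{-2}\right),
\end{aligned}
\qquad\text{where } s_j - n_j e^{x^*_j} - x^*_j/\sigma^2 = 0 .
```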

I guess the story is that even though I should know this stuff backwards, I apparently shouldn’t code it quickly!

2: The algebraic solver is flaky.

Instead of initialising to zero, I initialised to a point near the maximum likelihood estimate. In particular, I set xzero = log((sums + 0.1) ./ number_of_samples); (the 0.1 keeps the log finite for categories whose counts are all zero).

This seems to make the algebraic solver stable, at which point everything works. With an unlucky initialisation you sometimes see a single failure, but otherwise it's fine.
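A minimal Python sketch of that starting point, just to make the arithmetic concrete outside Stan. The helper `initial_point` and the toy data are hypothetical, and `index` uses Stan's 1-based labels. One assumption that goes beyond the post: categories with no observations get 0.0, which for this model is exactly their conditional mode, because their gradient reduces to -x/sigma^2.

```python
import math

def initial_point(y, index, M, offset=0.1):
    """Starting values log((sums + offset) / counts), one per category.

    Hypothetical helper. `index` holds 1-based category labels, as in Stan.
    The offset keeps the log finite for categories whose counts are all zero.
    Categories with no observations get 0.0 (an assumption): for them the
    gradient is just -x / sigma^2, so 0 is already the conditional mode.
    """
    sums = [0] * M
    counts = [0] * M
    for yi, j in zip(y, index):
        sums[j - 1] += yi
        counts[j - 1] += 1
    return [math.log((s + offset) / n) if n > 0 else 0.0
            for s, n in zip(sums, counts)]

# Made-up toy data: category 1 has counts 0 and 1, category 2 has 7 and 9,
# category 3 has no observations at all.
y = [0, 1, 7, 9]
index = [1, 1, 2, 2]
print(initial_point(y, index, M=3))
```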

Some things we saw in experiments:

  • The solver's performance depends on sigma, and it gets worse if you multiply the gradient through by sigma^2. Why?

    • Case 1 (original gradient): The Hessian grows large (through the 1/sigma^2 term) when sigma gets near zero. This is an area well supported by the prior. The algebraic solver should (and usually does) rein this in, and no problems occur. The solver has failed at most once in our tests with the above initialisation.
    • Case 2 (scaled gradient): Here the Hessian gets small when sigma is small, which stops the solver from taking large steps, so it tends to hit the maximum number of steps. We see this happen 5-10 times with the above initialisation.
  • The solve is sensitive to initialisation! This is a little surprising, because the function we are finding a root of (the gradient) AND its derivative are both monotone. With step-size control, that should be enough for the solver to succeed. It isn't.

    • It almost always fails when initialised to zero. This is a bit weird - it’s not a strange point in the space.
    • It almost always succeeds when initialised as above. With this data those points are -4.61 -4.38 1.97 2.01 0.01 -4.61 -2.15 2.38 3.17 -0.31, which aren't really that far from zero.
    • It almost always succeeds (maybe 1 or 2 failures) if initialised at ±1.0, with the sign chosen to be consistent with the above.
    • It fails maybe 5-10 times if initialised at ±0.1, with the sign chosen the same way.
    • EDIT: If we initialise to the log observed mean, then it also works (modulo occasionally seeing 1 failure).
    • EDIT2: When calculated appropriately, the log observed mean still seems to be a good starting point even when there are unobserved categories.
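Not a fix for algebra_solver, but a sketch of why this "should" work. The negative Hessian is diagonal here, so the M-dimensional system separates into M one-dimensional problems, and each component of the gradient is strictly decreasing, so a bracketed Newton iteration (Newton steps with a bisection fallback) converges from any starting point. That is a swapped-in technique, not what Stan's algebra_solver does internally, and inside Stan you would still need something like algebra_solver to get derivatives of the mode with respect to sigma. The function names and toy data below are hypothetical.

```python
import math

def grad(x, s, n, sigma):
    # One component of conditional_grad: s - n*exp(x) - x/sigma^2.
    return s - n * math.exp(x) - x / sigma**2

def neg_hess(x, n, sigma):
    # One component of conditional_neg_hessian: n*exp(x) + 1/sigma^2 (> 0).
    return n * math.exp(x) + 1.0 / sigma**2

def solve_one(s, n, sigma, x0, tol=1e-10, max_iter=200):
    """Bracketed (safeguarded) Newton for the root of grad(., s, n, sigma)."""
    # grad is strictly decreasing in x, so step outwards from x0 until the
    # sign change is bracketed: grad(lo) >= 0 >= grad(hi).
    lo, step = x0, 1.0
    while grad(lo, s, n, sigma) < 0:
        lo, step = lo - step, 2 * step
    hi, step = x0, 1.0
    while grad(hi, s, n, sigma) > 0:
        hi, step = hi + step, 2 * step
    x = x0
    for _ in range(max_iter):
        g, h = grad(x, s, n, sigma), neg_hess(x, n, sigma)
        if abs(g) / h < tol:        # the Newton step would be negligible
            return x
        if g > 0:
            lo = x                  # root lies to the right of x
        else:
            hi = x                  # root lies to the left of x
        x_new = x + g / h           # Newton step on the gradient
        # Accept the Newton step only if it stays inside the bracket.
        x = x_new if lo < x_new < hi else 0.5 * (lo + hi)
    return x

def conditional_mode(sums, counts, sigma, x0):
    # Diagonal negative Hessian: solve each category independently.
    return [solve_one(s, n, sigma, xi) for s, n, xi in zip(sums, counts, x0)]

# Made-up toy data (not the data from the post): per-category sums and counts.
sums, counts = [1, 16, 0], [2, 2, 0]
for start in ([0.0, 0.0, 0.0], [math.log(1.1 / 2), math.log(16.1 / 2), 0.0]):
    print(conditional_mode(sums, counts, 2.0, start))
```

Both starting points (zero and the offset log mean) should land on the same mode here; the bracket is what removes the dependence on initialisation.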

The output looks about right (roughly consistent with INLA):

```
Inference for Stan model: output.
1 chains, each with iter=2000; warmup=1000; thin=1;
post-warmup draws per chain=1000, total post-warmup draws=1000.

       mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff Rhat
sigma  2.73    0.04 0.69  1.62  2.23  2.66  3.10  4.27   351    1
x[1]  -3.16    0.04 1.37 -6.05 -4.04 -3.13 -2.17 -0.73  1000    1
x[2]  -2.96    0.05 1.42 -5.90 -3.89 -2.86 -1.99 -0.29   921    1
x[3]   1.97    0.00 0.11  1.74  1.89  1.97  2.05  2.20   865    1
x[4]   2.00    0.00 0.12  1.76  1.91  2.00  2.08  2.24   891    1
x[5]   0.02    0.01 0.34 -0.65 -0.22  0.01  0.25  0.69   825    1
x[6]  -3.12    0.05 1.35 -6.03 -3.95 -3.04 -2.19 -0.64   827    1
x[7]  -2.08    0.02 0.66 -3.39 -2.52 -2.07 -1.62 -0.89  1000    1
x[8]   2.37    0.00 0.09  2.19  2.31  2.37  2.43  2.54  1000    1
x[9]   3.16    0.00 0.11  2.96  3.09  3.16  3.23  3.38   965    1
x[10] -0.30    0.01 0.36 -0.95 -0.55 -0.32 -0.04  0.40  1000    1

Samples were drawn using NUTS(diag_e) at Fri Oct 13 14:28:03 2017.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
```

Now, here's the rub. This is a fairly easy problem, so it was easy to get the initial values close to the conditional mode. How easy that is in general remains to be seen. Edit: Actually, if the log sample mean is good enough, then this might not be too bad…

The working code is below. Thanks, everyone, for your help. Any further comments would definitely be appreciated!

Edit: Simpler initialisation to the log sample mean.

```stan
functions {
  // Gradient of log p(x | y, sigma) for x_j ~ normal(0, sigma), y ~ poisson_log(x[index]).
  // Signature is the one algebra_solver expects: (vector y, vector theta, real[] x_r, int[] x_i).
  vector conditional_grad(vector x, vector sigma, real[] number_of_samples, int[] sums) {
    vector[dims(x)[1]] result;
    result = (to_vector(sums) - to_vector(number_of_samples) .* exp(x)) - x / sigma[1]^2;
    return result;
  }
  // Diagonal of the negative Hessian of log p(x | y, sigma) (it is diagonal for this model).
  vector conditional_neg_hessian(vector x, real sigma, real[] number_of_samples) {
    vector[dims(x)[1]] result;
    result = to_vector(number_of_samples) .* exp(x) + 1 / sigma^2;
    return result;
  }
}
data {
  int N;
  int M;
  int y[N];
  int<lower=1, upper=M> index[N];
}
transformed data {
  vector[M] xzero = rep_vector(0.0, M);
  real number_of_samples[M];
  int sums[M];
  // Per-category totals and counts.
  for (j in 1:M) {
    sums[j] = 0;
    number_of_samples[j] = 0.0;
  }
  for (i in 1:N) {
    sums[index[i]] += y[i];
    number_of_samples[index[i]] += 1.0;
  }

  // xzero = log((to_vector(sums) + 0.1) ./ to_vector(number_of_samples));
  { // Beware of empty categories!!!!!!
    // Initialise every component to the log of the average observed mean over
    // the non-empty categories (assumes at least one nonzero count, otherwise
    // log(0) = -inf).
    int tmp = M;
    real summm = 0.0;
    for (i in 1:M) {
      if (number_of_samples[i] == 0) {
        tmp = tmp - 1;
      } else {
        summm = summm + sums[i] / number_of_samples[i];
      }
    }
    xzero = rep_vector(log(summm / tmp), M);
  }
}
parameters {
  //vector[M] group_mean;
  real<lower=0> sigma;
}
transformed parameters {
  vector[1] sigma_tmp;          // algebra_solver needs theta as a vector
  vector[M] conditional_mode;   // x^* = argmax_x p(x | y, sigma)
  sigma_tmp[1] = sigma;
  conditional_mode = algebra_solver(conditional_grad, xzero, sigma_tmp, number_of_samples, sums);
}
model {
  vector[M] laplace_precisions;
  sigma ~ normal(0, 2);
  laplace_precisions = conditional_neg_hessian(conditional_mode, sigma, number_of_samples);
  // log of p(y | x^*) p(x^* | sigma) / p_G(x^* | sigma, y), p_G = Gaussian approximation
  for (i in 1:N) {
    target += poisson_log_lpmf(y[i] | conditional_mode[index[i]]);
  }
  target += -0.5 * dot_self(conditional_mode) / sigma^2 - M * log(sigma)
            - 0.5 * sum(log(laplace_precisions));
}
generated quantities {
  vector[M] x;
  {
    // Draw x from the Gaussian (Laplace) approximation to p(x | y, sigma).
    vector[M] laplace_precisions = conditional_neg_hessian(conditional_mode, sigma, number_of_samples);
    for (i in 1:M) {
      x[i] = normal_rng(conditional_mode[i], inv_sqrt(laplace_precisions[i]));
    }
  }
}
```
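To sanity-check the target outside Stan, here is a Python sketch that evaluates the same Laplace approximation to log p(y | sigma) (the model block minus the prior on sigma) and compares it against brute-force quadrature for a single category. Assumptions: the mode is found by plain bisection on [-50, 50], which covers reasonable count data; the helper names and toy data are made up. Because the -½ log 2π in each normal prior term cancels the +½ log 2π from the Laplace step, the model-block expression is the full Laplace log-marginal with no constant dropped.

```python
import math

def mode_1d(s, n, sigma, lo=-50.0, hi=50.0, iters=200):
    # Root of s - n*exp(x) - x/sigma^2 (strictly decreasing) by bisection.
    # Assumes the root lies in [lo, hi].
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if s - n * math.exp(mid) - mid / sigma**2 > 0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

def laplace_log_marginal(y, index, M, sigma):
    """Laplace approximation to log p(y | sigma), mirroring the model block."""
    sums, counts = [0] * M, [0] * M
    for yi, j in zip(y, index):      # index is 1-based, as in Stan
        sums[j - 1] += yi
        counts[j - 1] += 1
    mode = [mode_1d(s, n, sigma) for s, n in zip(sums, counts)]
    # poisson_log_lpmf terms, including the -lgamma(y + 1) normalisation
    lp = sum(yi * mode[j - 1] - math.exp(mode[j - 1]) - math.lgamma(yi + 1)
             for yi, j in zip(y, index))
    lp += -0.5 * sum(x * x for x in mode) / sigma**2 - M * math.log(sigma)
    lp += -0.5 * sum(math.log(n * math.exp(x) + 1 / sigma**2)
                     for n, x in zip(counts, mode))
    return lp

def log_marginal_quadrature(y, sigma, lo=-20.0, hi=20.0, k=40001):
    # Brute-force check for a single category: trapezoid rule on a fine grid,
    # integrating prod_i Poisson(y_i | exp(x)) * normal(x | 0, sigma) over x.
    h = (hi - lo) / (k - 1)
    logf = []
    for i in range(k):
        x = lo + i * h
        ll = sum(yi * x - math.exp(x) - math.lgamma(yi + 1) for yi in y)
        lprior = (-0.5 * math.log(2 * math.pi) - math.log(sigma)
                  - 0.5 * x * x / sigma**2)
        logf.append(ll + lprior)
    m = max(logf)                    # subtract the max to avoid underflow
    total = sum((0.5 if i in (0, k - 1) else 1.0) * math.exp(v - m)
                for i, v in enumerate(logf))
    return m + math.log(h * total)

# Made-up toy data: one category with two observations.
y = [7, 9]
print(laplace_log_marginal(y, [1, 1], M=1, sigma=2.0))
print(log_marginal_quadrature(y, sigma=2.0))
```

With a total count of 16 in the single category, I'd expect the two numbers to agree closely, since the Laplace error shrinks as the counts grow; sparse categories (like the ones with posterior means around -3 above) are where the approximation is weakest.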