Algebraic solver problems (Laplace approximation in Stan)

Soooooooooooooooo

@avehtari and I have been trying to get a Laplace approximation working in Stan. For simple models, all of the components are in place. So it should work. But the algebraic solver is breaking. I’m out of ideas, so anyone who can chime in is very welcome.

The basic idea is this. Assume the following model: Poisson data where the log-mean is modelled with an iid Gaussian random effect.
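In symbols (matching the Stan code below):

y[i] ~ Poisson(exp(group_mean[index[i]])), for i in 1:N
group_mean[j] ~ Normal(0, sigma), for j in 1:M
sigma ~ Normal(0, 1), constrained to sigma > 0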

The Stan code for the basic model is as follows (here index is the imaginatively named index of the group each observation belongs to, taking values in 1:M):

data {
  int N;
  int M;
  int y[N];
  int<lower=1, upper=M> index[N];
}
parameters {
  vector[M] group_mean;
  real<lower=0> sigma;
}
model {
  group_mean ~ normal(0,sigma);
  sigma ~ normal(0,1);
  for (i in 1:N) {
   y[i] ~ poisson_log(group_mean[index[i]]);
  }
}

This works fine. Everything you would expect to happen happens.

Now, let’s try a Laplace approximation. This essentially replaces the Poisson log-likelihood with a normal centred at the maximum of p(group_mean | sigma, y), with a variance given by the inverse of the Hessian at that maximum.
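Concretely, writing x = group_mean, the approximation is

p(x | sigma, y) ≈ Normal(x | x_hat, H(x_hat)^-1),

where x_hat = argmax_x log p(x | sigma, y) and H(x_hat) is the negative Hessian of log p(x | sigma, y) at x_hat (diagonal here, since the groups decouple).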

The following code implements this. The two derivatives are analytic.

functions {

  vector conditional_grad(vector x, vector sigma, real[] number_of_samples, int[] sums) {
    vector[dims(x)[1]] result;
    result = sigma[1]^2*(to_vector(sums)-to_vector(number_of_samples).*exp(x)) - x;
    return result;
  }
  vector conditional_neg_hessian(vector x,  real sigma) {
    vector[dims(x)[1]] result;
    result = exp(x) + 1/sigma^2;
    return result;
  }
}

data {
  int N;
  int M;
  int y[N];
  int<lower=1, upper=M> index[N];
}
transformed data {
  vector[M] xzero = rep_vector(0.0, M);
  real number_of_samples[M];
  int sums[M];

  for (j in 1:M) {
    sums[j] = 0;
    number_of_samples[j]=0.0;
  }

  for (i in 1:N) {
    sums[index[i]] += y[i];
    number_of_samples[index[i]] += 1.0;
  }
}
parameters {
  //vector[M] group_mean;
  real<lower=0> sigma;
}
model {
  vector[M] conditional_mode;
  vector[M] laplace_precisions;
  vector[1] sigma_tmp;
  sigma_tmp[1] = sigma;

  //group_mean ~ normal(0, sigma);
  sigma ~ normal(0,1);

  conditional_mode = algebra_solver(conditional_grad, xzero, sigma_tmp, number_of_samples, sums );

  laplace_precisions = conditional_neg_hessian(conditional_mode, sigma);


  for (i in 1:N) {
    target += -0.5*laplace_precisions[index[i]]*(y[i]- conditional_mode[index[i]])^2;
  }
}

This is a disaster. Now I’m surprised, because a Poisson-log-normal model is about as nice as you can get (the conditional distribution is log-concave, so the root finding should not have any problems). Instead of this working, I get this error constantly and the sampler doesn’t move.

Iteration: 2000 / 2000 [100%] (Sampling)
Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:
Exception: normal_lpdf: Random variable is nan, but must not be nan! (in '/Users/ds921/Desktop/laplace.stan' at line 50)

If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,
but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.

Any suggestions?

Files here:
data.R (921 Bytes)
make_data.R (131 Bytes)
laplace.stan (1.2 KB)
mmc.stan (253 Bytes)

(the one that didn’t come)
laplace.stan (1.2 KB)

Perhaps try with a prior on sigma that has no density at zero?

Same thing happens if sigma ~ gamma(2,1); is used.

I also know that this problem has a solution (I can run it in INLA, and SBAC [i.e. the new Cook-Gelman-Rubin check] verifies that INLA works for this problem).

I can get some chains to run (with a smaller init_r), but some fail with this message:

[2] "  Exception: algebra_solver: the norm of the algebraic function is: 70.2677 but should be lower than the function tolerance: 1e-06. Consider increasing the relative tolerance and the max_num_steps.  (in 'model6f5b66d54378_laplace' at line 53)"

Specifying init = "0" seems to always work, but the chains slow to a crawl once adaptation kicks in.

Hmmmmmm… I’m using cmdstan for my sins (i.e. I don’t have a version of RStan running 2.17).

But I’m confused as to why the init needs to be so small: the “true” value of sigma is 2 (the posterior mean from the full model is slightly higher), which I would’ve expected not to be badly scaled. On the log scale, the correct posterior for sigma (the only parameter) is in (-0.52, 1.1).

My experience so far is that the chains either fail completely immediately or never move. But again, this is a very easy zero-finding problem (finding the mode of a log-concave density), so I’m not sure what’s wrong. It also decouples, so you can plot the curves and they’re not that weird… (almost linear in a broad neighbourhood of the zero).

I don’t know, and I have to go teach now. But if you move conditional_mode and laplace_precisions into transformed parameters, run a few iterations without warmup, ignore the fact that all the transitions immediately diverge, and look at the output anyway, it appears that conditional_mode[5] goes off the rails:

> summary(as.matrix(post))
    sigma[1] conditional_mode[1] conditional_mode[2] conditional_mode[3]
 Min.   :1   Min.   :-1.746      Min.   :-1.606      Min.   :1.947      
 1st Qu.:1   1st Qu.:-1.746      1st Qu.:-1.606      1st Qu.:1.947      
 Median :1   Median :-1.746      Median :-1.606      Median :1.947      
 Mean   :1   Mean   :-1.746      Mean   :-1.606      Mean   :1.947      
 3rd Qu.:1   3rd Qu.:-1.746      3rd Qu.:-1.606      3rd Qu.:1.947      
 Max.   :1   Max.   :-1.746      Max.   :-1.606      Max.   :1.947      
 conditional_mode[4] conditional_mode[5]   conditional_mode[6]
 Min.   :1.978       Min.   :-3.727e-281   Min.   :-1.746     
 1st Qu.:1.978       1st Qu.:  0.000e+00   1st Qu.:-1.746     
 Median :1.978       Median :  0.000e+00   Median :-1.746     
 Mean   :1.978       Mean   : 1.479e-282   Mean   :-1.746     
 3rd Qu.:1.978       3rd Qu.:  0.000e+00   3rd Qu.:-1.746     
 Max.   :1.978       Max.   : 1.479e-280   Max.   :-1.746     
 conditional_mode[7] conditional_mode[8] conditional_mode[9]
 Min.   :-1.607      Min.   :2.361       Min.   :3.134      
 1st Qu.:-1.607      1st Qu.:2.361       1st Qu.:3.134      
 Median :-1.607      Median :2.361       Median :3.134      
 Mean   :-1.607      Mean   :2.361       Mean   :3.134      
 3rd Qu.:-1.607      3rd Qu.:2.361       3rd Qu.:3.134      
 Max.   :-1.607      Max.   :2.361       Max.   :3.134      
 conditional_mode[10] laplace_precisions[1] laplace_precisions[2]
 Min.   :-0.2836      Min.   :1.175         Min.   :1.201        
 1st Qu.:-0.2836      1st Qu.:1.175         1st Qu.:1.201        
 Median :-0.2836      Median :1.175         Median :1.201        
 Mean   :-0.2836      Mean   :1.175         Mean   :1.201        
 3rd Qu.:-0.2836      3rd Qu.:1.175         3rd Qu.:1.201        
 Max.   :-0.2836      Max.   :1.175         Max.   :1.201        
 laplace_precisions[3] laplace_precisions[4] laplace_precisions[5]
 Min.   :8.005         Min.   :8.225         Min.   :2            
 1st Qu.:8.005         1st Qu.:8.225         1st Qu.:2            
 Median :8.005         Median :8.225         Median :2            
 Mean   :8.005         Mean   :8.225         Mean   :2            
 3rd Qu.:8.005         3rd Qu.:8.225         3rd Qu.:2            
 Max.   :8.005         Max.   :8.225         Max.   :2            
 laplace_precisions[6] laplace_precisions[7] laplace_precisions[8]
 Min.   :1.175         Min.   :1.2           Min.   :11.6         
 1st Qu.:1.175         1st Qu.:1.2           1st Qu.:11.6         
 Median :1.175         Median :1.2           Median :11.6         
 Mean   :1.175         Mean   :1.2           Mean   :11.6         
 3rd Qu.:1.175         3rd Qu.:1.2           3rd Qu.:11.6         
 Max.   :1.175         Max.   :1.2           Max.   :11.6         
 laplace_precisions[9] laplace_precisions[10]      lp__       
 Min.   :23.97         Min.   :1.753          Min.   :-30806  
 1st Qu.:23.97         1st Qu.:1.753          1st Qu.:-30806  
 Median :23.97         Median :1.753          Median :-30806  
 Mean   :23.97         Mean   :1.753          Mean   :-30806  
 3rd Qu.:23.97         3rd Qu.:1.753          3rd Qu.:-30806  
 Max.   :23.97         Max.   :1.753          Max.   :-30806  

The algebraic solves decouple in this case, so the one corresponding to x[5] is
sigma^2*(8 - 8*exp(x[5])) - x[5] = 0, which isn’t exactly pathological. (The answer from uniroot in R is -2.230104e-05 when sigma = 2, which is not far from the starting point.)

The most “interesting” equation you get is sigma^2*(0 - 10*exp(x[1])) - x[1] = 0. The root here when sigma=2 is -2.535393.
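Plotting that curve (a quick R sketch, with sigma = 2 as above) shows how tame it is:

sigma <- 2
# gradient for the x[1] group (sum of counts 0, 10 observations); nearly linear near the root
curve(sigma^2 * (0 - 10 * exp(x)) - x, from = -5, to = 0, xlab = "x[1]", ylab = "gradient")
abline(h = 0, lty = 2)
abline(v = -2.535393, lty = 3)  # the root quoted above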

If we vary sigma and do all of these solves, the modes vary like this:

  • sigma = 1e-3: the modes vary from -1e-5 to 1e-4
  • sigma = 0.01: modes vary from -1e-3 to 1e-2
  • sigma = 0.1: modes vary from -0.1 to 1
  • sigma = 1: modes vary from -1.7 to 3.1
  • sigma = 10: modes vary from -5.3 to 3.2.

I don’t think any of those values are scary, but maybe someone else is seeing something I’m not.

The R code for producing that is (once you’ve loaded the data file in the first post)

sigma <- 10
for (i in 1:10) print(uniroot(function(x) sigma^2*(sum(y[index==i]) - sum(index==i)*exp(x)) - x, c(-50, 50))$root)

sigma[1]^2*(to_vector(sums)-to_vector(number_of_samples).*exp(x)) - x;

What is this the gradient of exactly?

When I try to write out the lpdf of:

for (i in 1:N) {
 y[i] ~ poisson_log(group_mean[index[i]]);
}

And then compute the gradient of the total lp, I get a vector that looks something like:

to_vector(sums) .* x - to_vector(number_of_samples) .* exp(x)

I feel like I’m wrong here though haha… It was just the only thing that stood out to me.

edit: Yup, I’m wrong
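(For the record, and assuming I’ve done the algebra right this time: summing the Poisson log-pmf over the observations in group j gives sums[j]*x[j] - number_of_samples[j]*exp(x[j]) + const, and adding the normal prior term -x[j]^2/(2*sigma^2), the gradient with respect to x[j] is

sums[j] - number_of_samples[j]*exp(x[j]) - x[j]/sigma^2,

which is what conditional_grad computes, up to the overall factor of sigma^2.)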

If you use optimization instead of HMC, you get a mode where sigma = 14.76, so something is definitely off.

It seems that Stan gets the same answer here, which would suggest that maybe the implicit differentiation is wrong?

But the finite diff matches, so now I have no idea.

I’m really confused. I used optimize and the results were consistent with the simple estimator of the log-mean log(observed/expected). I also checked and the algebraic solver really is finding zeros.

But NUTS still isn’t moving.

Does anyone (maybe @charlesm93?) know how Stan computes the derivative of the algebraic solve? I had a quick look in the code but couldn’t find anything. My next best guess as to why this isn’t working is that that derivative isn’t being computed smoothly…

Edit: I found the autodiff for the algebraic solver. It is doing the thing that seems like it wouldn’t cause problems (using the analytic gradient). So again, I have no idea. Help!
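(Specifically, it looks like it differentiates the solution implicitly: if f(x, theta) = 0 defines x_hat(theta), then dx_hat/dtheta = -(df/dx)^-1 * (df/dtheta), evaluated at the solution and using the analytic Jacobian of f. That’s the standard implicit function theorem, so it shouldn’t be the issue.)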

Edit 2: Is there a way to get cmdstan to just output the log posterior on a grid? Because the sampler isn’t moving, I can’t just use the output. But because this is a 1D problem, plotting should be quite helpful…

Edit 3: Just because for some reason I didn’t add them before, here’s the stansummary output for the optimize run

Inference for Stan model: laplace_model
1 chains: each with iter=(1); warmup=(0); thin=(0); 1 iterations saved.

Warmup took (0.00) seconds, 0.00 seconds total
Sampling took (0.00) seconds, 0.00 seconds total

                              Mean     MCSE   StdDev       5%      50%      95%  N_Eff  N_Eff/s    R_hat
lp__                      -3.0e+04      nan      nan      nan      nan      nan    1.0      inf      nan
sigma                      1.5e+01      nan      nan      nan      nan      nan    1.0      inf      nan
sigma_tmp[1]               1.5e+01      nan      nan      nan      nan      nan    1.0      inf      nan
conditional_mode[1]       -5.9e+00      nan      nan      nan      nan      nan    1.0      inf      nan
conditional_mode[2]       -5.7e+00      nan      nan      nan      nan      nan    1.0      inf      nan
conditional_mode[3]        2.0e+00      nan      nan      nan      nan      nan    1.0      inf      nan
conditional_mode[4]        2.0e+00      nan      nan      nan      nan      nan    1.0      inf      nan
conditional_mode[5]      -2.3e-311      nan      nan      nan      nan      nan    1.0      inf      nan
conditional_mode[6]       -5.9e+00      nan      nan      nan      nan      nan    1.0      inf      nan
conditional_mode[7]       -2.2e+00      nan      nan      nan      nan      nan    1.0      inf      nan
conditional_mode[8]        2.4e+00      nan      nan      nan      nan      nan    1.0      inf      nan
conditional_mode[9]        3.2e+00      nan      nan      nan      nan      nan    1.0      inf      nan
conditional_mode[10]      -3.2e-01      nan      nan      nan      nan      nan    1.0      inf      nan
this_should_be_zero[1]    -8.0e-11      nan      nan      nan      nan      nan    1.0      inf      nan
this_should_be_zero[2]    -3.3e-12      nan      nan      nan      nan      nan    1.0      inf      nan
this_should_be_zero[3]     1.1e-12      nan      nan      nan      nan      nan    1.0      inf      nan
this_should_be_zero[4]    -1.6e-12      nan      nan      nan      nan      nan    1.0      inf      nan
this_should_be_zero[5]    2.3e-311      nan      nan      nan      nan      nan    1.0      inf      nan
this_should_be_zero[6]     8.3e-11      nan      nan      nan      nan      nan    1.0      inf      nan
this_should_be_zero[7]    -1.1e-11      nan      nan      nan      nan      nan    1.0      inf      nan
this_should_be_zero[8]    -1.9e-12      nan      nan      nan      nan      nan    1.0      inf      nan
this_should_be_zero[9]     2.6e-12      nan      nan      nan      nan      nan    1.0      inf      nan
this_should_be_zero[10]   -5.2e-13      nan      nan      nan      nan      nan    1.0      inf      nan
sigma_tmp_tmp[1]           1.5e+01      nan      nan      nan      nan      nan    1.0      inf      nan

Samples were drawn using lbfgs with .
For each parameter, N_Eff is a crude measure of effective sample size,
and R_hat is the potential scale reduction factor on split chains (at 
convergence, R_hat=1).

The values for log(observed / expected) are

-Inf  -Inf  1.9715526  2.0074680  0.0000000 -Inf -2.1972246  2.3812282  3.1675825 -0.3184537

which is consistent with the conditional_mode parameters. this_should_be_zero is computed as

this_should_be_zero=conditional_grad(conditional_mode,sigma_tmp,number_of_samples,sums);

and is indeed zero (which means the algebraic solver works. Thanks Charles!)
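For reference, the log(observed / expected) values above can be reproduced from the data with something like the following one-liner, where “expected” is just the number of observations in each group:

log(tapply(y, index, sum) / tapply(y, index, length))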

The (slightly updated) laplace.stan is below.

functions {
  vector conditional_grad(vector x, vector sigma, real[] number_of_samples, int[] sums) {
    vector[dims(x)[1]] result;
    result = sigma[1]^2*(to_vector(sums)-to_vector(number_of_samples).*exp(x)) - x;
    return result;
  }
  vector conditional_neg_hessian(vector x,  real sigma) {
    vector[dims(x)[1]] result;
    result = exp(x) + 1/sigma^2;
    return result;
  }
}
data {
  int N;
  int M;
  int y[N];
  int<lower=1, upper=M> index[N];
}
transformed data {
  vector[M] xzero = rep_vector(0.0, M);
  real number_of_samples[M];
  int sums[M];
  for (j in 1:M) {
    sums[j] = 0;
    number_of_samples[j]=0.0;
  }
  for (i in 1:N) {
    sums[index[i]] += y[i];
    number_of_samples[index[i]] += 1.0;
  }
}
parameters {
  //vector[M] group_mean;
  real<lower=0> sigma;
}
transformed parameters {
  vector[1] sigma_tmp;
  vector[M] conditional_mode;
  sigma_tmp[1] = sigma;
  conditional_mode = algebra_solver(conditional_grad, xzero, sigma_tmp, number_of_samples, sums );
}
model {
  vector[M] laplace_precisions;
  sigma ~ gamma(2,1);
  laplace_precisions = conditional_neg_hessian(conditional_mode, sigma);
  for (i in 1:N) {
    target += -0.5*laplace_precisions[index[i]]*(y[i]- conditional_mode[index[i]])^2;
  }
}

generated quantities {
  vector[M] this_should_be_zero;
  vector[1] sigma_tmp_tmp;
  sigma_tmp_tmp[1] = sigma;
  this_should_be_zero=conditional_grad(conditional_mode,sigma_tmp,number_of_samples,sums);
}

@bgoodri What parameter values did you use to test for a match?

Here’s a link to the design of the solver’s autodiff (for the record), by @betanalpha : https://global.discourse-cdn.com/standard14/uploads/mc_stan/original/1X/61820bb07cf80c325ee4bf6b892c255b55cb3c92.pdf
I could add an entry to the user manual, if people think it would be helpful. As noted, issues can arise if the matrix inversion is not done properly (maybe the matrix is not invertible in the first place, etc.). But I’d have to work out the math to see if this applies to your case. We could always print the output of the Jacobian and see if it dissolves to NaN as the Markov chain moves.

I’ll take a guess that he tested with sigma = 1, because the actual gradient is

 result = to_vector(sums) - to_vector(number_of_samples).*exp(x) - x/sigma[1]^2;

but I multiplied through by sigma[1]^2 for stability.

It doesn’t change anything (I just checked again)


I tested a bunch of values for sigma. Any time the algebraic solver returned a solution, I got a match between the analytic derivative and the finite-difference derivative. Like this:

TESTING GRADIENT FOR MODEL 'laplace' NOW (CHAIN 1).
TEST GRADIENT MODE

 Log probability=-29896.4

 param idx           value           model     finite diff           error
         0         1.14473         218.842         218.841      0.00110442

Sounds like the next step is a (failing) unit test?

I had another look at this this morning. Still confused, but you can program in a custom Newton iteration (since the Hessian is diagonal) and things sample. The answer super doesn’t agree with the reference model, but the should_be_zero stuff is indeed zero. Anyway, might help.

I added a to_vector(number_of_samples) factor to the Hessian, which I think should be there, and changed conditional_grad back to have the sigma in its original place.
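(The Newton iteration in the transformed parameters block below is just the component-wise update x[j] <- x[j] - f(x[j]) / f'(x[j]), with f = conditional_grad and f' = -conditional_neg_hessian, run for a fixed 10 iterations.)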

Inference for Stan model: dsimp_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.

Warmup took (0.3290, 0.3292, 0.3151, 0.3151) seconds, 1.288 seconds total
Sampling took (0.3355, 0.3664, 0.3313, 0.3411) seconds, 1.374 seconds total

                                Mean       MCSE     StdDev          5%         50%         95%  N_Eff  N_Eff/s      R_hat
lp__                      -1.403e+05  2.041e-02  8.704e-01  -1.403e+05  -1.403e+05  -1.403e+05   1818     1323  1.001e+00
accept_stat__              9.226e-01  1.756e-03  1.110e-01   6.897e-01   9.703e-01   1.000e+00   4000     2910  1.002e+00
stepsize__                 5.681e-01  1.650e-02  2.334e-02   5.324e-01   5.817e-01   5.946e-01  2.001    1.456  2.010e+13
treedepth__                1.368e+00  8.014e-03  4.824e-01   1.000e+00   1.000e+00   2.000e+00   3623     2636  1.001e+00
n_leapfrog__               2.474e+00  2.437e-02  1.152e+00   1.000e+00   3.000e+00   3.000e+00   2235     1626  1.013e+00
divergent__                0.000e+00  0.000e+00  0.000e+00   0.000e+00   0.000e+00   0.000e+00   4000     2910       -nan
energy__                   1.403e+05  2.555e-02  1.079e+00   1.403e+05   1.403e+05   1.403e+05   1783     1297  1.001e+00
sigma                      2.019e-01  1.163e-05  4.429e-04   2.011e-01   2.019e-01   2.026e-01   1449     1054  1.001e+00
sigma_tmp[1]               2.019e-01  1.163e-05  4.429e-04   2.011e-01   2.019e-01   2.026e-01   1449     1054  1.001e+00
conditional_mode[1]       -3.014e-01  2.670e-05  1.016e-03  -3.031e-01  -3.014e-01  -2.998e-01   1449     1054  1.001e+00
conditional_mode[2]       -2.531e-01  2.328e-05  8.863e-04  -2.545e-01  -2.531e-01  -2.516e-01   1449     1054  1.001e+00
conditional_mode[3]        1.400e+00  5.729e-05  2.181e-03   1.397e+00   1.400e+00   1.404e+00   1449     1054  1.001e+00
conditional_mode[4]        1.336e+00  6.431e-05  2.448e-03   1.332e+00   1.336e+00   1.340e+00   1449     1055  1.001e+00
conditional_mode[5]        0.000e+00  0.000e+00  0.000e+00   0.000e+00   0.000e+00   0.000e+00   4000     2910       -nan
conditional_mode[6]       -3.014e-01  2.670e-05  1.016e-03  -3.031e-01  -3.014e-01  -2.998e-01   1449     1054  1.001e+00
conditional_mode[7]       -4.068e-01  3.151e-05  1.200e-03  -4.088e-01  -4.068e-01  -4.049e-01   1449     1054  1.001e+00
conditional_mode[8]        1.888e+00  5.496e-05  2.092e-03   1.885e+00   1.888e+00   1.891e+00   1449     1054  1.001e+00
conditional_mode[9]        2.279e+00  1.014e-04  3.859e-03   2.273e+00   2.279e+00   2.285e+00   1449     1054  1.001e+00
conditional_mode[10]      -8.551e-02  6.983e-06  2.659e-04  -8.594e-02  -8.550e-02  -8.508e-02   1449     1054  1.001e+00
this_should_be_zero[1]     1.443e-17  1.252e-17  7.576e-16  -8.882e-16   0.000e+00   8.882e-16   3662     2665  1.000e+00
this_should_be_zero[2]     4.663e-18  1.105e-17  6.816e-16  -8.882e-16   0.000e+00   8.882e-16   3803     2767  9.998e-01
this_should_be_zero[3]    -2.487e-17  9.076e-17  5.431e-15  -7.105e-15   0.000e+00   7.105e-15   3580     2605  1.000e+00
this_should_be_zero[4]     1.155e-16  8.465e-17  5.142e-15  -7.105e-15   0.000e+00   7.105e-15   3689     2684  9.998e-01
this_should_be_zero[5]     0.000e+00  0.000e+00  0.000e+00   0.000e+00   0.000e+00   0.000e+00   4000     2910       -nan
this_should_be_zero[6]     1.443e-17  1.252e-17  7.576e-16  -8.882e-16   0.000e+00   8.882e-16   3662     2665  1.000e+00
this_should_be_zero[7]    -2.620e-17  1.781e-17  1.079e-15  -1.776e-15   0.000e+00   1.776e-15   3672     2672  9.997e-01
this_should_be_zero[8]    -1.030e-16  1.527e-16  8.978e-15  -1.421e-14   0.000e+00   1.421e-14   3459     2517  1.001e+00
this_should_be_zero[9]    -6.342e-16  1.514e-16  9.005e-15  -1.421e-14   0.000e+00   1.421e-14   3539     2575  9.995e-01
this_should_be_zero[10]    3.664e-18  1.176e-17  7.016e-16  -1.332e-15   0.000e+00   1.332e-15   3562     2592  9.997e-01
sigma_tmp_tmp[1]           2.019e-01  1.163e-05  4.429e-04   2.011e-01   2.019e-01   2.026e-01   1449     1054  1.001e+00

Model is:

functions {
  vector conditional_grad(vector x, vector sigma, real[] number_of_samples, int[] sums) {
    vector[dims(x)[1]] result;
    result = (to_vector(sums)-to_vector(number_of_samples).*exp(x)) - x / (sigma[1]^2);
    return result;
  }
  vector conditional_neg_hessian(real[] number_of_samples, vector x,  real sigma) {
    vector[dims(x)[1]] result;
    result = to_vector(number_of_samples) .* exp(x) + 1/sigma^2;
    return result;
  }
}
data {
  int N;
  int M;
  int y[N];
  int<lower=1, upper=M> index[N];
}
transformed data {
  vector[M] xzero = rep_vector(0.0, M);
  real number_of_samples[M];
  int sums[M];
  for (j in 1:M) {
    sums[j] = 0;
    number_of_samples[j]=0.0;
  }
  for (i in 1:N) {
    sums[index[i]] += y[i];
    number_of_samples[index[i]] += 1.0;
  }
}
parameters {
  //vector[M] group_mean;
  real<lower=0> sigma;
}
transformed parameters {
  vector[1] sigma_tmp;
  vector[M] conditional_mode = rep_vector(0, M);
  sigma_tmp[1] = sigma;
  {
    for(i in 1:10) {
      vector[M] f = conditional_grad(conditional_mode, sigma_tmp, number_of_samples, sums);
      vector[M] fp = -conditional_neg_hessian(number_of_samples, conditional_mode,  sigma);
      
      conditional_mode = conditional_mode - f ./ fp;
    }
  }
  //conditional_mode = algebra_solver(conditional_grad, xzero, sigma_tmp, number_of_samples, sums );
}
model {
  vector[M] laplace_precisions;
  sigma ~ gamma(2,1);
  laplace_precisions = conditional_neg_hessian(number_of_samples, conditional_mode, sigma);
  for (i in 1:N) {
    target += -0.5*laplace_precisions[index[i]]*(y[i]- conditional_mode[index[i]])^2;
  }
}

generated quantities {
  vector[M] this_should_be_zero;
  vector[1] sigma_tmp_tmp;
  sigma_tmp_tmp[1] = sigma;
  this_should_be_zero=conditional_grad(conditional_mode,sigma_tmp,number_of_samples,sums);
}

Oh, I guess the number_of_samples thing wasn’t there because the Hessian you need for that minimization is slightly different from the Hessian you’d use here: target += -0.5*laplace_precisions[index[i]]*(y[i] - conditional_mode[index[i]])^2;?

If you go with two Hessians then you get:

sigma                      7.114e+00  1.259e-02  5.140e-01   6.299e+00   7.097e+00   7.981e+00   1668  7.001e+02  9.998e-01
sigma_tmp[1]               7.114e+00  1.259e-02  5.140e-01   6.299e+00   7.097e+00   7.981e+00   1668  7.001e+02  9.998e-01
conditional_mode[1]       -4.679e+00  2.921e-03  1.190e-01  -4.873e+00  -4.679e+00  -4.483e+00   1661  6.973e+02  9.997e-01
conditional_mode[2]       -4.496e+00  2.900e-03  1.182e-01  -4.689e+00  -4.496e+00  -4.301e+00   1661  6.974e+02  9.997e-01
conditional_mode[3]        1.971e+00  1.801e-06  7.298e-05   1.971e+00   1.971e+00   1.971e+00   1643  6.896e+02  9.996e-01
conditional_mode[4]        2.007e+00  2.160e-06  8.761e-05   2.007e+00   2.007e+00   2.007e+00   1646  6.909e+02  9.996e-01
conditional_mode[5]        2.722e-17  7.176e-19  4.151e-17  -4.374e-17   2.804e-17   9.623e-17   3346  1.405e+03  1.000e+00
conditional_mode[6]       -4.679e+00  2.921e-03  1.190e-01  -4.873e+00  -4.679e+00  -4.483e+00   1661  6.973e+02  9.997e-01
conditional_mode[7]       -2.176e+00  7.582e-05  3.076e-03  -2.180e+00  -2.176e+00  -2.170e+00   1645  6.908e+02  9.997e-01
conditional_mode[8]        2.381e+00  1.446e-06  5.851e-05   2.381e+00   2.381e+00   2.381e+00   1638  6.877e+02  9.997e-01
conditional_mode[9]        3.167e+00  2.405e-06  9.745e-05   3.167e+00   3.167e+00   3.167e+00   1642  6.893e+02  9.997e-01
conditional_mode[10]      -3.177e-01  2.850e-06  1.156e-04  -3.178e-01  -3.177e-01  -3.175e-01   1645  6.906e+02  9.997e-01

for the approx. and

group_mean[1]    -3.219e+00  1.945e-02  1.230e+00  -5.520e+00  -3.071e+00  -1.523e+00   4000     5148  9.998e-01
group_mean[2]    -3.096e+00  2.044e-02  1.293e+00  -5.557e+00  -2.894e+00  -1.331e+00   4000     5148  1.002e+00
group_mean[3]     1.959e+00  1.808e-03  1.143e-01   1.761e+00   1.963e+00   2.142e+00   4000     5148  9.995e-01
group_mean[4]     1.992e+00  1.958e-03  1.238e-01   1.781e+00   1.996e+00   2.190e+00   4000     5148  9.999e-01
group_mean[5]    -5.891e-02  5.709e-03  3.611e-01  -6.883e-01  -3.878e-02   5.085e-01   4000     5148  9.995e-01
group_mean[6]    -3.265e+00  2.732e-02  1.343e+00  -5.756e+00  -3.041e+00  -1.467e+00   2418     3112  1.002e+00
group_mean[7]    -2.162e+00  1.043e-02  6.595e-01  -3.325e+00  -2.116e+00  -1.191e+00   4000     5148  1.001e+00
group_mean[8]     2.372e+00  1.501e-03  9.495e-02   2.214e+00   2.373e+00   2.529e+00   4000     5148  1.000e+00
group_mean[9]     3.154e+00  1.630e-03  1.031e-01   2.979e+00   3.155e+00   3.323e+00   4000     5148  9.994e-01
group_mean[10]   -3.717e-01  5.598e-03  3.540e-01  -9.905e-01  -3.462e-01   1.801e-01   4000     5148  9.998e-01
sigma             2.202e+00  6.813e-03  4.309e-01   1.571e+00   2.154e+00   2.959e+00   4000     5148  1.002e+00

for the actual model. Sigma seems screwed up, but the conditional modes don’t look too bad.