Nested Laplace approximation roadmap

#1

Tagging @avehtari, @Daniel_Simpson, @seantalts, @betanalpha.

Latent Gaussian models

Suppose we have a latent Gaussian model
\phi \sim \pi(\phi) \\ \theta \sim \mathrm{Normal}(\mu(\phi), \Sigma(\phi)) \\ y_{j \in g(i)} \sim \pi(y | \theta_i, \phi)
where g(i) is the set of observations in group i.

When we use a nested Laplace approximation to marginalize out \theta, we can show that, up to an additive constant, \log p(\phi | y) \approx \log p(\phi) + \alpha, where:

\alpha = \log p(y | \theta^*, \phi) - \frac{1}{2} \left( \log \det \Sigma_\phi + \log \det H + [\theta^* - \mu(\phi)]^T \Sigma_\phi^{-1} [\theta^* - \mu(\phi)] \right)

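where \Sigma_\phi = \Sigma(\phi), \theta^* is the mode of p(\theta | y, \phi), and H = -\nabla_\theta^2 \log p(\theta | y, \phi) \big|_{\theta = \theta^*} is the negative Hessian with respect to \theta, evaluated at the mode. Because the prior on \theta is Gaussian, H = \Sigma_\phi^{-1} - \nabla_\theta^2 \log p(y | \theta, \phi).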
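As a sanity check on the formula for \alpha, here is a minimal pure-Python sketch of our own (not part of any Stan interface). It covers the one case where the Laplace approximation is exact: a normal observation model with one observation per group, y_i \sim \mathrm{Normal}(\theta_i, s^2). Taking H as the negative Hessian of \log p(\theta | y, \phi) gives H = \Sigma_\phi^{-1} + I/s^2, and the mode has a closed form. In this case \alpha should equal the exact log marginal likelihood \log \mathrm{Normal}(y | \mu, \Sigma_\phi + s^2 I). The function names and the small Cholesky helpers are illustrative assumptions.

```python
import math


def cholesky(A):
    """Lower-triangular L with A = L L^T, for symmetric positive-definite A."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = A[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def chol_solve(L, b):
    """Solve (L L^T) x = b by forward and back substitution."""
    n = len(L)
    z = [0.0] * n
    for i in range(n):
        z[i] = (b[i] - sum(L[i][k] * z[k] for k in range(i))) / L[i][i]
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (z[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x


def logdet(L):
    """log det(L L^T) computed from the Cholesky factor L."""
    return 2.0 * sum(math.log(L[i][i]) for i in range(len(L)))


def laplace_alpha_gaussian(y, mu, Sigma, s):
    """alpha for y_i ~ Normal(theta_i, s^2), theta ~ Normal(mu, Sigma)."""
    n = len(y)
    Ls = cholesky(Sigma)
    # Sigma^{-1} column by column (symmetric, so rows equal columns)
    Sigma_inv = [chol_solve(Ls, [float(r == c) for r in range(n)]) for c in range(n)]
    # H = negative Hessian of log p(theta | y, phi) = Sigma^{-1} + I / s^2
    H = [[Sigma_inv[i][k] + (1.0 / s**2 if i == k else 0.0) for k in range(n)] for i in range(n)]
    Lh = cholesky(H)
    # Mode: H theta* = Sigma^{-1} mu + y / s^2
    rhs = [sum(Sigma_inv[i][k] * mu[k] for k in range(n)) + y[i] / s**2 for i in range(n)]
    theta = chol_solve(Lh, rhs)
    d = [theta[i] - mu[i] for i in range(n)]
    quad = sum(di * v for di, v in zip(d, chol_solve(Ls, d)))
    loglik = sum(-0.5 * math.log(2 * math.pi * s**2) - 0.5 * ((y[i] - theta[i]) / s) ** 2
                 for i in range(n))
    return loglik - 0.5 * (logdet(Ls) + logdet(Lh) + quad)


def exact_log_marginal(y, mu, Sigma, s):
    """log Normal(y | mu, Sigma + s^2 I), the exact marginal likelihood."""
    n = len(y)
    L = cholesky([[Sigma[i][k] + (s**2 if i == k else 0.0) for k in range(n)] for i in range(n)])
    r = [y[i] - mu[i] for i in range(n)]
    quad = sum(ri * v for ri, v in zip(r, chol_solve(L, r)))
    return -0.5 * (n * math.log(2 * math.pi) + logdet(L) + quad)


y = [0.3, -1.2, 0.8]
mu = [0.0, 0.0, 0.0]
Sigma = [[1.0, 0.5, 0.2], [0.5, 1.0, 0.5], [0.2, 0.5, 1.0]]
print(laplace_alpha_gaussian(y, mu, Sigma, 0.7), exact_log_marginal(y, mu, Sigma, 0.7))
```

The two printed numbers should agree up to floating-point rounding. For a non-Gaussian likelihood such as the Poisson, they would differ by the error of the approximation.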

Function specification

The mode is found with an algebraic solver, and the Hessian is known analytically for certain observation models, notably the Poisson, normal, binomial, and negative binomial. For these cases, we can write a Stan function that automates the calculation of \alpha. A Stan model might then look like this:

functions {
  // functors that return individual elements of mu(phi) (a vector) and Sigma(phi) (a matrix)
  real mu_phi(real phi, real y, int i) { /* user-defined */ }
  real sigma_phi(real phi, real y, int i, int j) { /* user-defined */ }
}

data {
  int M;  // number of groups
  int N;  // number of observations
  int y[N];  // counts for Poisson/binomial; use real y[N] for a normal likelihood
  int index[N];  // group for each observation
}

parameters {
  real phi;
}

transformed parameters {
  // vector mu_phi = ...;
  // matrix sigma_phi = ...;
  static vector[M] theta_0 = rep_vector(0, M);  // proposed `static` qualifier: persists across iterations
}

model {
  phi ~ p(.);  // placeholder for the prior on phi

  target += laplace_lgp_poisson(theta_0, mu_phi, sigma_phi, phi, y, index);
  // or
  target += laplace_lgp_binomial(theta_0, mu_phi, sigma_phi, phi, y, index);
}
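To make the intended semantics more concrete, here is a pure-Python sketch of what a function like laplace_lgp_poisson might compute. It rests on several assumptions: a log link (y_j \sim \mathrm{Poisson}(\exp(\theta_{index[j]}))), dense \mu and \Sigma passed as arrays rather than as functors, and 0-based group indices. The function runs Newton's method from theta_0 to find \theta^* and then evaluates \alpha. It mirrors the proposed Stan signature only loosely. The names, the stopping rule, and the helpers are ours.

```python
import math


def cholesky(A):
    """Lower-triangular L with A = L L^T, for symmetric positive-definite A."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = A[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def chol_solve(L, b):
    """Solve (L L^T) x = b by forward and back substitution."""
    n = len(L)
    z = [0.0] * n
    for i in range(n):
        z[i] = (b[i] - sum(L[i][k] * z[k] for k in range(i))) / L[i][i]
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (z[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x


def logdet(L):
    """log det(L L^T) computed from the Cholesky factor L."""
    return 2.0 * sum(math.log(L[i][i]) for i in range(len(L)))


def laplace_lgp_poisson(theta_0, mu, Sigma, y, index, tol=1e-10, max_iter=100):
    """Hypothetical Python analogue of the proposed Stan function (0-based index).

    Model: theta ~ Normal(mu, Sigma), y[j] ~ Poisson(exp(theta[index[j]])).
    Returns (alpha, theta_star); alpha approximates log p(y | phi).
    """
    M = len(mu)
    Ls = cholesky(Sigma)
    # Columns of Sigma^{-1}; by symmetry Sinv[i][k] is its (i, k) entry.
    Sinv = [chol_solve(Ls, [float(r == c) for r in range(M)]) for c in range(M)]
    y_sum, n_obs = [0.0] * M, [0] * M
    for yj, g in zip(y, index):
        y_sum[g] += yj
        n_obs[g] += 1

    def neg_hessian(theta):
        # H = Sigma^{-1} + diag(n_i * exp(theta_i)) for the Poisson log link
        return [[Sinv[i][k] + (n_obs[i] * math.exp(theta[i]) if i == k else 0.0)
                 for k in range(M)] for i in range(M)]

    theta = list(theta_0)
    for _ in range(max_iter):
        d = [theta[i] - mu[i] for i in range(M)]
        grad = [y_sum[i] - n_obs[i] * math.exp(theta[i])
                - sum(Sinv[i][k] * d[k] for k in range(M)) for i in range(M)]
        step = chol_solve(cholesky(neg_hessian(theta)), grad)  # Newton step: H step = grad
        theta = [t + s for t, s in zip(theta, step)]
        if max(abs(s) for s in step) < tol:
            break
    else:
        raise RuntimeError("Newton iteration did not converge")

    d = [theta[i] - mu[i] for i in range(M)]
    quad = sum(di * v for di, v in zip(d, chol_solve(Ls, d)))
    loglik = sum(yj * theta[g] - math.exp(theta[g]) - math.lgamma(yj + 1)
                 for yj, g in zip(y, index))
    alpha = loglik - 0.5 * (logdet(Ls) + logdet(cholesky(neg_hessian(theta))) + quad)
    return alpha, theta


alpha, theta_star = laplace_lgp_poisson([0.0, 0.0], [1.0, 0.0],
                                        [[0.25, 0.1], [0.1, 0.25]], [3, 5, 0, 1], [0, 0, 1, 1])
```

For a single group, \alpha can be checked against one-dimensional numerical integration of p(y | \phi). With a handful of moderate counts the two should be close but not identical, because the Laplace approximation is not exact for the Poisson.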

Remark 1: the Laplace functions take functors as arguments, so they are somewhat painful to implement. It might be worth waiting for Stan 3 and using lambda functions.

Remark 2: sigma_phi and mu_phi are functors that compute individual elements of a matrix or vector. This allows more localization, reduces memory use, and produces smaller autodiff trees. See the function spec for Gaussian processes.

Remark 3: in many cases, \Sigma_\phi is sparse.
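The element-wise functors of Remark 2 can be illustrated with a short Python sketch. The assumptions are ours: the covariance is a squared-exponential kernel with length scale phi and unit marginal variance, the mean is zero, and the middle argument is a list of covariates x (the Stan sketch passes y there). A caller can request single elements, such as the diagonal, without forming the whole matrix.

```python
import math


def mu_phi(phi, x, i):
    """Hypothetical mean functor: element i of mu(phi). Zero mean in this sketch."""
    return 0.0


def sigma_phi(phi, x, i, j):
    """Hypothetical covariance functor: element (i, j) of Sigma(phi).

    Assumes a squared-exponential kernel with length scale phi and unit
    marginal variance, evaluated at covariates x.
    """
    return math.exp(-0.5 * ((x[i] - x[j]) / phi) ** 2)


def assemble(phi, x):
    """Build the dense mu and Sigma only when a full matrix is actually needed."""
    M = len(x)
    mu = [mu_phi(phi, x, i) for i in range(M)]
    Sigma = [[sigma_phi(phi, x, i, j) for j in range(M)] for i in range(M)]
    return mu, Sigma


x = [0.0, 0.5, 1.0, 3.0]
diag = [sigma_phi(1.2, x, i, i) for i in range(len(x))]  # only the diagonal
mu, Sigma = assemble(1.2, x)
```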

Finding the curvature

We would like to generalize to a broader class of likelihoods. The issue is computing the Hessian of the log likelihood, \nabla_\theta^2 \log p(y | \theta, \phi). It is the only part of H that depends on the observation model, because the Gaussian prior contributes \Sigma_\phi^{-1}. Automatically computing the Hessian of an arbitrary function is feasible but can be costly. In practice (according to Dan), modelers work out the Hessian themselves and provide it. We could let users specify the Hessian and pass it to laplace_lgp_general.
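For the Poisson case with a log link (our assumption), the hand-derived quantities a user would supply are short. Each observation depends on a single \theta_i, so the Hessian of \log p(y | \theta, \phi) is diagonal, with entries -n_i e^{\theta_i}, where n_i is the number of observations in group i. The Python sketch below computes the gradient and that diagonal. The function names are hypothetical. The accompanying finite-difference comparison is the kind of check one might run before trusting a user-provided Hessian.

```python
import math


def poisson_loglik(theta, y, index):
    """log p(y | theta) with y[j] ~ Poisson(exp(theta[index[j]])), 0-based index."""
    return sum(yj * theta[g] - math.exp(theta[g]) - math.lgamma(yj + 1)
               for yj, g in zip(y, index))


def poisson_grad_hess(theta, y, index):
    """Hand-derived gradient and Hessian diagonal of poisson_loglik in theta.

    d/dtheta_i      = sum_{j in g(i)} (y_j - exp(theta_i))
    d^2/dtheta_i^2  = -n_i * exp(theta_i); off-diagonal entries are zero.
    """
    grad = [0.0] * len(theta)
    hess_diag = [0.0] * len(theta)
    for yj, g in zip(y, index):
        grad[g] += yj - math.exp(theta[g])
        hess_diag[g] -= math.exp(theta[g])
    return grad, hess_diag


theta = [0.3, -0.7]
y = [2, 0, 4, 1, 1]
index = [0, 1, 0, 1, 1]
grad, hess_diag = poisson_grad_hess(theta, y, index)
```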

Finding the mode

Finding the mode is a high-dimensional root-finding problem: we solve \nabla_\theta \log p(\theta | y, \phi) = 0 for \theta. Since the problem is (approximately) convex, a candidate approach is a Newton solver, in particular the KINSOL implementation (see earlier post). We also want to consider parallelizing the solver, which can be done separately from calculating the autodiff.
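The sketch below treats mode finding as root finding. It solves F(\theta) = \nabla_\theta \log p(\theta | y, \phi) = 0 with plain Newton steps J(\theta) \Delta = -F(\theta), where J, the Jacobian of F, is the Hessian of the log posterior. It is pure Python with a dense Gaussian-elimination solve, not KINSOL. It also omits the line search and other globalization options a production solver such as KINSOL offers. To keep it short, the prior is parameterized by its precision matrix Q = \Sigma_\phi^{-1} and the likelihood is Poisson with a log link. Both choices, and the toy data, are assumptions for illustration.

```python
import math


def lin_solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [list(row) + [b[i]] for i, row in enumerate(A)]  # augmented matrix
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= f * M[col][c]
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (M[i][n] - sum(M[i][c] * x[c] for c in range(i + 1, n))) / M[i][i]
    return x


def newton_root(F, J, x0, tol=1e-10, max_iter=50):
    """Plain Newton iteration for F(x) = 0, given the Jacobian J(x)."""
    x = list(x0)
    for it in range(1, max_iter + 1):
        step = lin_solve(J(x), [-v for v in F(x)])
        x = [xi + si for xi, si in zip(x, step)]
        if max(abs(s) for s in step) < tol:
            return x, it
    raise RuntimeError("Newton iteration did not converge")


# Toy problem (assumed): two groups, Poisson counts with a log link,
# prior theta ~ Normal(mu, Q^{-1}) specified through its precision matrix Q.
y, index = [3, 5, 4, 6, 2, 0, 1, 2], [0, 0, 0, 0, 0, 1, 1, 1]
mu, Q = [1.0, 0.0], [[4.0, -1.0], [-1.0, 4.0]]
y_sum = [sum(yj for yj, g in zip(y, index) if g == i) for i in range(2)]
n_obs = [index.count(i) for i in range(2)]


def F(theta):
    # Gradient of log p(theta | y, phi): likelihood term minus Q (theta - mu)
    return [y_sum[i] - n_obs[i] * math.exp(theta[i])
            - sum(Q[i][k] * (theta[k] - mu[k]) for k in range(2)) for i in range(2)]


def J(theta):
    # Jacobian of F: -diag(n_i exp(theta_i)) - Q
    return [[-(n_obs[i] * math.exp(theta[i]) if i == k else 0.0) - Q[i][k] for k in range(2)]
            for i in range(2)]


theta_star, iters = newton_root(F, J, [0.0, 0.0])
```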

We also want to use an initial guess \theta_0 (theta_0 in the code), based on solutions from previous iterations, to gain a speedup. Sean and I discussed this a little. Note that this applies to any problem that uses the solver. Basically, theta_0 is declared as a static variable, which persists across iterations. This could be a reasonable scheme if the leapfrog steps are small enough when simulating Hamiltonian trajectories.

static vector[M] theta_0 = rep_vector(0, M);  // proposed syntax: value persists across iterations
theta_0 = algebra_solver(system, theta_0, parm, x_r, x_i);

See link. To safeguard against misuse, we could require certain arguments, such as initial guesses, to be static variables.
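The effect of a warm start shows up even in a one-group toy problem. The sketch assumes a Poisson likelihood with a log link and a prior \theta \sim \mathrm{Normal}(\phi, 0.5^2). The WarmStart class below is a hypothetical stand-in for the proposed static variable: it keeps the last mode and uses it as the next initial guess. Solving at \phi = 1.0 and then at \phi = 1.05 from the stored mode takes fewer Newton iterations than solving at \phi = 1.05 from zero.

```python
import math


def newton_mode(theta0, phi, y, sigma=0.5, tol=1e-10, max_iter=100):
    """Mode of p(theta | y, phi) for one group.

    Assumed model: y_j ~ Poisson(exp(theta)), theta ~ Normal(phi, sigma^2).
    Returns (mode, number_of_newton_iterations).
    """
    n, s = len(y), sum(y)
    theta = theta0
    for it in range(1, max_iter + 1):
        grad = s - n * math.exp(theta) - (theta - phi) / sigma**2
        neg_hess = n * math.exp(theta) + 1.0 / sigma**2
        step = grad / neg_hess
        theta += step
        if abs(step) < tol:
            return theta, it
    raise RuntimeError("Newton iteration did not converge")


class WarmStart:
    """Hypothetical stand-in for a `static` initial guess that persists across calls."""

    def __init__(self, theta0=0.0):
        self.theta = theta0

    def solve(self, phi, y):
        self.theta, iters = newton_mode(self.theta, phi, y)
        return self.theta, iters


y = [3, 5, 4, 6, 2]
ws = WarmStart()
ws.solve(1.0, y)                           # first evaluation: cold start from 0
_, warm_iters = ws.solve(1.05, y)          # nearby phi: start from the stored mode
_, cold_iters = newton_mode(0.0, 1.05, y)  # same phi, cold start
print(warm_iters, cold_iters)
```

Because the solver stops at a tolerance, the returned mode depends on the initial guess only at roughly the level of that tolerance. That is the price of letting state persist between evaluations.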

#2

What if more than one system is solved in the same model? For instance, say we need one solve per patient in a pharma model. How would you associate this variable with the right system?

I don’t know of a way to enforce that in C++.

#3

What if there’s more than one system being solved in the same model?

Presumably, you use a different initial guess for each system. So I might declare a static matrix with each row corresponding to a patient.

I don’t know of a way to enforce that in C++.

Not sure either; I need to think about this.
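One way to picture the per-patient answer is a store with one row of initial guesses per patient, indexed by patient number. The Python sketch below is hypothetical. Each patient has independent latent components with Poisson counts (log link) and prior \mathrm{Normal}(\phi, 0.5^2), and PerPatientGuesses plays the role of the static matrix. Nothing in it ties row p to patient p except the caller passing the right index, which is exactly the association question raised above.

```python
import math


def newton_mode(theta0, phi, y, sigma=0.5, tol=1e-10, max_iter=100):
    """Mode for one latent component: y_j ~ Poisson(exp(theta)), theta ~ Normal(phi, sigma^2)."""
    n, s = len(y), sum(y)
    theta = theta0
    for it in range(1, max_iter + 1):
        step = ((s - n * math.exp(theta) - (theta - phi) / sigma**2)
                / (n * math.exp(theta) + 1.0 / sigma**2))
        theta += step
        if abs(step) < tol:
            return theta, it
    raise RuntimeError("Newton iteration did not converge")


class PerPatientGuesses:
    """Hypothetical stand-in for a static matrix: row p holds patient p's last solution."""

    def __init__(self, n_patients, dim):
        self.rows = [[0.0] * dim for _ in range(n_patients)]

    def solve(self, p, phi, y_p):
        # y_p[k] holds patient p's counts for latent component k; row p is the initial guess.
        results = [newton_mode(guess, phi, y_pk) for guess, y_pk in zip(self.rows[p], y_p)]
        self.rows[p] = [theta for theta, _ in results]
        return self.rows[p], sum(iters for _, iters in results)


data = [[[3, 5, 4], [0, 1]],     # patient 0: two latent components
        [[10, 12], [7, 9, 8]]]   # patient 1
store = PerPatientGuesses(n_patients=2, dim=2)
for p, y_p in enumerate(data):
    store.solve(p, 1.0, y_p)     # first pass fills one row per patient
```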

#4

Is the idea that this static thing is part of the Stan language? Otherwise, I don’t see how to roll this into the C++ API without something like an open-ended map.