Approximate Gaussian process: increase speed and efficiency

Hi all,

I am looking for guidance on how to increase the speed and efficiency of a reduced-rank GP in my model. In particular, I am surprised by the number of basis functions I seem to need.

Data

The data consist of monthly birth counts in Switzerland over 25 years, stratified by maternal age (15–50). This results in 25 × 12 × 36 = 10,800 observations.

Model

I use a negative binomial distribution, parameterized via a mean \mu and an overdispersion parameter. The log mean number of births in each (time, age) cell is modeled as:

\log\mu_{t,a} = \log f(a,\alpha(t),\beta,\lambda) + \log N_{t,a}

with N_{t,a} being the female population size of age a at time t. The function f(a,\alpha,\beta,\lambda) is the probability of giving birth as a function of maternal age a and time t (time is the month ID, from 1 to 360). It is modeled as a Gaussian curve over age, with three parameters (written out after the list):

  • peak age (time-varying) \alpha(t)

  • peak height (fixed) \beta

  • curve width (fixed) \lambda
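
For reference, this just writes out what the log_birth_prob function in the code below computes: the log birth probability at age a and time t is

\log f(a,\alpha(t),\beta,\lambda) = \log\beta - \frac{(a - \alpha(t))^2}{2\lambda^2}

so \beta enters as the peak height, \alpha(t) as the peak location, and \lambda as the curve width.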

The peak age is allowed to vary smoothly over time and is parameterized through a Gaussian process (with an exponential transformation to keep it positive):

\alpha(t) = \alpha_0 \exp( GP(t))

where the GP is a squared-exponential Gaussian process over time.

The GP is implemented using a practical Hilbert-space approximation on a bounded domain (sine basis with truncated spectral density), following Solin & Särkkä and the Birthdays case study.
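
Concretely, the approximation used in the code below (PHI_EQ and diagSPD_EQ) represents the GP as

GP(t) \approx \sum_{m=1}^{M} \sqrt{S(\omega_m)}\,\phi_m(t)\,\beta_m, \qquad \phi_m(t) = \frac{1}{\sqrt{L}}\sin\!\left(\frac{\pi m (t+L)}{2L}\right), \qquad \omega_m = \frac{\pi m}{2L}

where S is the spectral density of the squared-exponential kernel with marginal standard deviation alpha_year and lengthscale lambda_year, \beta_m \sim \mathrm{normal}(0,1), and L is the boundary c_year \cdot \max(x_n) on the standardized time scale.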

Issue

At monthly resolution, I need about M = 50 basis functions for stable sampling. For smaller M, I observe:

  • high R-hat for the GP lengthscale (called lambda_year in the script below)

  • either max-treedepth warnings or a few divergences when tightening the prior

With M = 50, sampling is stable but slow (≈400 seconds).
By contrast, using a yearly GP works well and fast with M = 10.

Questions

  1. Is it expected that a monthly GP over ~25 years requires M ≈ 50 under a Hilbert-space approximation, or does this suggest a suboptimal parameterization?

  2. Would it be reasonable to fix or strongly constrain the GP lengthscale given the truncated spectral support?

  3. Could the multiplicative structure exp(GP) be driving these issues, and is it reasonable to normalize the GP by a fixed factor to keep the exponent in a reasonable range?

Here are the estimates with M = 50 basis functions:

The lengthscale lambda_year is quite high because the change in peak age over time is close to linear:

Full model code below.

Thanks in advance for any insight.

functions {
  //see https://github.com/avehtari/casestudies/blob/master/Birthdays
  // basis functions (exponentiated quadratic kernel)
  matrix PHI_EQ(int N, int M, real L, vector x) {
    matrix[N,M] A = rep_matrix(pi()/(2*L) * (x+L), M);
    vector[M] B = linspaced_vector(M, 1, M);
    matrix[N,M] PHI = sin(diag_post_multiply(A, B))/sqrt(L);
    for (m in 1:M) PHI[,m] = PHI[,m] - mean(PHI[,m]); // center each basis function to have mean 0
    return PHI;
  }
  // spectral density (exponentiated quadratic kernel)
  vector diagSPD_EQ(real alpha, real lambda, real L, int M) {
    vector[M] B = linspaced_vector(M, 1, M);
    return sqrt( alpha^2 * sqrt(2*pi()) * lambda * exp(-0.5*(lambda*pi()/(2*L))^2*B^2) );
  }
  
  vector log_birth_prob(vector age_id, vector a_peak, real log_h_peak, real sigma) {
    return log_h_peak - (age_id - a_peak)^2 / (2*sigma^2);
  }
  
  real log_birth_prob2(real age_id, real a_peak,real log_h_peak, real sigma) {
    return log_h_peak - (age_id - a_peak)^2 / (2*sigma^2);
  }
}

data {
  int<lower=1> N;
  int<lower=1> N_year; // used to scale the GP before exponentiation
  int<lower=1> N_month;
  int<lower=1> N_age;
  
  //vector[N] n_birth;
  array[N] int n_birth;
  vector[N] n_pop;
  
  array[N] int month_id;
  array[N] int age_id;
  
  // Time points and corresponding locations
  vector[N_month] x;
  
  // GP basis function settings
  real<lower=0> c_year;           // Boundary scaling factor
  int M_year;                     // Number of EQ basis functions
  
  // Hyperprior parameters
  array[2] real p_age_peak1;
  array[2] real p_log_h_peak1;
  array[2] real p_birth_prob_sigma1;
  
  array[2] real p_alpha_year;
  array[2] real p_lambda_year;
  
  real p_inv_sigma;
  
  int inference;
}

transformed data {
  // normalize data
  real x_mean = mean(x);
  real x_sd = sd(x);
  vector[N_month] xn = (x - x_mean)/x_sd;
  
  // compute boundary value
  real L_year = c_year*max(xn);
  
  // compute basis functions for f
  matrix[N_month,M_year] PHI_year = PHI_EQ(N_month, M_year, L_year, xn);
}

parameters {
  real <lower=0> inv_sigma;
  
  real <lower=0> age_peak1;
  real log_h_peak1;
  real <lower=0> birth_prob_sigma1;
  
  // GPs
  real <lower=0> alpha_year;       // Yearly GP scale by age
  real <lower=0> lambda_year;      // Yearly GP lengthscale by age
  vector[M_year] beta_year; // Basis coefficients for yearly GP
}
transformed parameters {
  vector[M_year] diagSPD_year;
  vector[N_month] f_year;
  
  // compute spectral densities for f
  diagSPD_year = diagSPD_EQ(alpha_year, lambda_year, L_year, M_year);
  // compute f
  f_year = PHI_year * (diagSPD_year .* beta_year);
  
  real sigma = inv(inv_sigma);
}
model {
  inv_sigma ~ exponential(p_inv_sigma);
  
  // weak priors
  age_peak1 ~ normal(p_age_peak1[1],p_age_peak1[2]);
  log_h_peak1 ~ normal(p_log_h_peak1[1],p_log_h_peak1[2]);
  birth_prob_sigma1 ~ normal(p_birth_prob_sigma1[1],p_birth_prob_sigma1[2]);
  
  //GP: variance and lengthscale
  lambda_year ~ lognormal(p_lambda_year[1], p_lambda_year[2]);
  alpha_year ~ normal(p_alpha_year[1], p_alpha_year[2]);
  beta_year ~ normal(0, 1);
  
  if(inference==1){
    target += neg_binomial_2_log_lpmf(n_birth | log_birth_prob(to_vector(age_id),
                                                               age_peak1 * exp(f_year[month_id]/N_year), // scaling the exponent by N_year may reduce the risk of divergences
                                                               log_h_peak1,
                                                               birth_prob_sigma1) + log(n_pop), sigma);
  }
}

generated quantities{
  array[N_month, N_age] real birth_prob;
  
  for(i in 1:N_month){
    for(j in 1:N_age){
      birth_prob[i,j] = exp(log_birth_prob2(j, age_peak1 * exp(f_year[i]/N_year),
                                            log_h_peak1,
                                            birth_prob_sigma1));
    }
  }
}

Thanks for posting. This is way beyond my own GP expertise, so I think we need @avehtari or @andrewgelman here—they’re the birthday problem experts :-).

This is the second time someone’s mentioned Solin and Särkkä’s approximation to me in the last two weeks, so I put it on my stack of things to try to understand. If you’re only working in 1D and are OK with just a point estimate of a mean, there are some super-scalable spectral methods for GPs which are not approximate (e.g., this paper by some of my colleagues here at Flatiron, Uniform approximation of common Gaussian process kernels using equispaced Fourier grids - ScienceDirect). You can stretch them to do variance, but not to do sampling as far as I understand.

To make a guess: it’s possible your prior on length-scale is permitting the sampler to explore low length-scale values, and at these low values M = 50 is required to produce robust results. I would guess that the likelihood space is poorly defined at these low length-scales if M is set too low.
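
If I remember the recommendations in Riutort-Mayol et al.'s practical HSGP paper correctly, that would be consistent with their guidance: for the squared-exponential kernel the number of basis functions needed scales roughly like 1.75 c / lengthscale (with the lengthscale measured relative to the half-range of the normalized inputs), so a prior that lets the sampler visit small lengthscales quickly drives up the M required for an accurate approximation.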

You could try using the tighter inverse-gamma priors described in Robust Gaussian Processes in Stan to restrict the length-scale to higher values.

Here’s R code to derive those priors if helpful:

# Find inverse-gamma(alpha, theta) hyperparameters such that
# P(lengthscale < l) = tail_prob and P(lengthscale > u) = tail_prob.
# Uses the fact that X ~ inv-gamma(alpha, theta) iff 1/X ~ gamma(alpha, rate = theta).
find_inv_gamma_params <- function(l, u, tail_prob = 0.01) {
  f <- function(x) {
    alpha <- exp(x[1])
    theta <- exp(x[2])
    
    # P(X < l) and P(X < u) for the inverse gamma, via the gamma CDF of 1/X
    cdf_l <- pgamma(1/l, shape = alpha, rate = theta, lower.tail = FALSE)
    cdf_u <- pgamma(1/u, shape = alpha, rate = theta, lower.tail = FALSE)
    
    y1 <- cdf_l - tail_prob
    y2 <- cdf_u - (1.0 - tail_prob)
    
    return(c(y1, y2))
  }
  
  initial_guess <- c(log(5.0), log(5.0))
  
  require(nleqslv)
  result <- nleqslv(initial_guess, f)
  
  if (result$termcd %in% c(1, 2)) {
    sol <- exp(result$x)
    print(paste0("inv_gamma(", round(sol[1], digits = 3), ", ", round(sol[2], digits = 3), ")"))
    return(list(alpha = sol[1], theta = sol[2]))
  } else {
    stop(sprintf("Solver failed to converge. Return code: %d, Message: %s", 
                 result$termcd, result$message))
  }
}
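
For example (hypothetical numbers), calling it with l = 0.2 and u = 2 on the normalized time scale gives an inverse-gamma prior that puts roughly 1% of its mass below 0.2 and 1% above 2, which you could then use as the prior on lambda_year.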

And I believe the max tree depth warnings are generally not an issue - I have run into them a lot when using this approximation and results seem OK.

For 10,800 observations, and considering that the posterior is far from normal, 400 s sounds reasonable. If there is such a big difference in the number of basis functions needed for yearly versus monthly data, maybe there is a seasonal pattern, and you could use an additive model with a slowly changing component plus a more wiggly periodic component, as in the Birthdays case study; that would reduce the number of basis functions needed.
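
To make the additive idea concrete, here is a minimal, self-contained sketch (not the full Birthdays construction): a slow trend built with the same PHI_EQ / diagSPD_EQ machinery but few basis functions, plus a 12-month seasonal component written as plain cosine/sine harmonics instead of the periodic-kernel HSGP used in the case study. The likelihood is a plain normal on a generic series y, and names such as y, M_trend, M_season, c, and sigma_season are made up for this sketch.

// Sketch only: additive slow trend + seasonal component
functions {
  matrix PHI_EQ(int N, int M, real L, vector x) {
    matrix[N,M] A = rep_matrix(pi()/(2*L) * (x+L), M);
    vector[M] B = linspaced_vector(M, 1, M);
    matrix[N,M] PHI = sin(diag_post_multiply(A, B))/sqrt(L);
    for (m in 1:M) PHI[,m] = PHI[,m] - mean(PHI[,m]); // center each basis function
    return PHI;
  }
  vector diagSPD_EQ(real alpha, real lambda, real L, int M) {
    vector[M] B = linspaced_vector(M, 1, M);
    return sqrt(alpha^2 * sqrt(2*pi()) * lambda * exp(-0.5*(lambda*pi()/(2*L))^2 * square(B)));
  }
}
data {
  int<lower=1> N;        // number of months
  vector[N] x;           // month index (1, 2, ..., N)
  vector[N] y;           // some monthly series (stand-in for the real outcome)
  int<lower=1> M_trend;  // few basis functions are enough for the slow trend
  int<lower=1> M_season; // number of seasonal harmonics
  real<lower=0> c;       // boundary factor
}
transformed data {
  vector[N] xn = (x - mean(x)) / sd(x);
  real L = c * max(xn);
  matrix[N, M_trend] PHI_trend = PHI_EQ(N, M_trend, L, xn);
  matrix[N, M_season] C;  // cosine harmonics with a 12-month period
  matrix[N, M_season] S;  // sine harmonics with a 12-month period
  for (k in 1:M_season) {
    for (i in 1:N) {
      C[i, k] = cos(2 * pi() * k * x[i] / 12);
      S[i, k] = sin(2 * pi() * k * x[i] / 12);
    }
  }
}
parameters {
  real<lower=0> alpha_trend;   // trend GP marginal SD
  real<lower=0> lambda_trend;  // trend GP lengthscale
  vector[M_trend] beta_trend;
  vector[M_season] beta_cos;
  vector[M_season] beta_sin;
  real<lower=0> sigma_season;  // overall scale of the seasonal effect
  real<lower=0> sigma_obs;
}
transformed parameters {
  vector[N] f_trend = PHI_trend * (diagSPD_EQ(alpha_trend, lambda_trend, L, M_trend) .* beta_trend);
  vector[N] f_season = sigma_season * (C * beta_cos + S * beta_sin);
  vector[N] f = f_trend + f_season;  // additive: slow trend + seasonal pattern
}
model {
  alpha_trend ~ normal(0, 1);
  lambda_trend ~ lognormal(0, 1);
  beta_trend ~ normal(0, 1);
  beta_cos ~ normal(0, 1);
  beta_sin ~ normal(0, 1);
  sigma_season ~ normal(0, 0.5);
  sigma_obs ~ normal(0, 1);
  y ~ normal(f, sigma_obs);
}

In your model, f would replace f_year inside exp(), with the trend component needing far fewer basis functions than a single GP that has to capture the within-year wiggles.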

exp(GP) makes the posterior more challenging, but it’s not the main reason.

HSGP is likely to have such a posterior that it would benefit from a partial non-centered parameterization (otherwise you can observe both max-treedepth exceedances and divergences). I hope @Niko will soon post a blog post about that.

Some of these could be used with Stan, too, and they can be faster than HSGP, but they have some limitations that make them less general, which is why I have advocated HSGP.

There are many different basis function approximations with about similar performance. Solin and Särkkä did show that HSGP was best for given number of basis functions among some set of approximations, but the differences can be negligible and there can be other approximations with even smaller difference in performance. I decided to use HSGP as I know Solin and Särkkä well and did get some useful advice from Solin. Anyway, how to go from priors on function space to basis functions and priors on coefficients is cool thing to learn!