Hi all,
I am looking for guidance on how to increase the speed and efficiency of a reduced-rank GP in my model. In particular, I am surprised by the number of basis functions I seem to need.
Data
The data consist of monthly birth counts in Switzerland over 25 years, stratified by maternal age (15–50). This results in 25 × 12 × 36 ≈ 10,800 observations.
Model
I use a negative binomial distribution, parameterized via a mean $\mu$ and an overdispersion parameter. The log mean number of births in each (time, age) cell is modeled as:
$$\log \mu_{a,t} = \log N_{a,t} + \log f(a,\alpha,\beta,\lambda)$$
with $N_{a,t}$ being the female population size of age $a$ at time $t$. The function $f(a,\alpha,\beta,\lambda)$ is the probability of giving birth at maternal age $a$ and time $t$ (time is the month ID, from 1 to 360). It is modeled using a Gaussian curve over age, with three parameters:
- peak age (time-varying) $\alpha(t)$
- peak height (fixed) $\beta$
- curve width (fixed) $\lambda$
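For reference, the Gaussian curve over age can be read off the `log_birth_prob` function in the model code below. The mapping of symbols to code variables ($\beta$ as `exp(log_h_peak1)`, $\lambda$ as `birth_prob_sigma1`) is inferred from the script rather than stated in the text, so treat it as an assumption:

```latex
f(a, \alpha(t), \beta, \lambda) = \beta \, \exp\left( -\frac{(a - \alpha(t))^2}{2 \lambda^2} \right)
```

Note that $\lambda$ here is the curve width over age, not the GP lengthscale `lambda_year` used in the script.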
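The peak age is allowed to vary smoothly over time and is parameterized through a Gaussian process, using an exponential transformation to keep it positive. As implemented in the code below:
$$\alpha(t) = \alpha_0 \exp\left(\frac{g(t)}{N_{\text{year}}}\right)$$
where $\alpha_0$ is the baseline peak age (`age_peak1`), $N_{\text{year}}$ is a fixed scaling constant, and $g$ is a squared-exponential Gaussian process over time.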
The GP is implemented using a practical Hilbert-space approximation on a bounded domain (sine basis with truncated spectral density), following Solin & Särkkä and the Birthdays case study.
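To get a feel for how many basis functions a given lengthscale needs, the sketch below mirrors `PHI_EQ` (without the column centering) and `diagSPD_EQ` from the Stan code in plain Python and computes the prior variance that the truncated expansion reproduces at one point. The exact GP has variance `alpha**2` everywhere, so the shortfall is a rough indication of how much the truncation loses. The function name `approx_variance` and all numeric values (`L = 2.6`, the lengthscales, the values of `M`) are hypothetical and chosen only for illustration; lengthscales are in units of the standardized time axis `xn`.

```python
import math

def diag_spd_eq(alpha, lam, L, M):
    """Square root of the EQ spectral density at the M basis frequencies
    (same formula as diagSPD_EQ in the Stan code)."""
    weights = []
    for j in range(1, M + 1):
        omega = j * math.pi / (2 * L)
        spd = alpha**2 * math.sqrt(2 * math.pi) * lam * math.exp(-0.5 * (lam * omega) ** 2)
        weights.append(math.sqrt(spd))
    return weights

def phi_eq(x, L, M):
    """Sine basis functions at a single point x in [-L, L]
    (same as PHI_EQ in the Stan code, but without centering)."""
    return [math.sin(j * math.pi / (2 * L) * (x + L)) / math.sqrt(L)
            for j in range(1, M + 1)]

def approx_variance(x, alpha, lam, L, M):
    """Prior variance of the truncated expansion at x, assuming
    independent standard-normal basis coefficients."""
    w = diag_spd_eq(alpha, lam, L, M)
    phi = phi_eq(x, L, M)
    return sum((wj * pj) ** 2 for wj, pj in zip(w, phi))

if __name__ == "__main__":
    L = 2.6  # hypothetical boundary on the standardized time axis
    for lam in (0.2, 0.5, 1.0):  # hypothetical lengthscales
        for M in (5, 10, 20, 50):
            v = approx_variance(0.0, 1.0, lam, L, M)
            print(f"lam={lam:<4} M={M:<3} variance captured={v:.4f}")
```

In this sketch the required `M` is driven by the lengthscale relative to `L`, not by the number of time points: the time axis is standardized before the basis is built, so a monthly grid and a yearly grid over the same period span a similar standardized range.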
Issue
At monthly resolution, I need about M = 50 basis functions for stable sampling. For smaller M, I observe:
- high R-hat for the GP lengthscale (called `lambda_year` in the script below)
- either max-treedepth warnings or a few divergences when tightening the prior
With M = 50, sampling is stable but slow (≈400 seconds).
By contrast, a GP at yearly resolution samples quickly and without problems with M = 10.
Questions
- Is it expected that a monthly GP over ~25 years requires M ≈ 50 under a Hilbert-space approximation, or does this suggest a suboptimal parameterization?
- Would it be reasonable to fix or strongly constrain the GP lengthscale given the truncated spectral support?
- Could the multiplicative structure `exp(GP)` be driving these issues, and is it reasonable to normalize the GP by a fixed factor to keep the exponent in a reasonable range?
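On the last question, one thing that can be checked directly: the weights returned by `diagSPD_EQ` are linear in `alpha_year`, so dividing `f_year` by a fixed constant should be equivalent to leaving `f_year` unscaled and using a marginal standard deviation of `alpha_year / N_year`. The Python sketch below verifies this numerically for the basis weights. The helper name `max_rel_diff` and the numeric values are hypothetical.

```python
import math

def diag_spd_eq(alpha, lam, L, M):
    """Square root of the EQ spectral density at the M basis frequencies
    (same formula as diagSPD_EQ in the Stan code)."""
    weights = []
    for j in range(1, M + 1):
        omega = j * math.pi / (2 * L)
        spd = alpha**2 * math.sqrt(2 * math.pi) * lam * math.exp(-0.5 * (lam * omega) ** 2)
        weights.append(math.sqrt(spd))
    return weights

def max_rel_diff(alpha, lam, L, M, n_year):
    """Largest relative difference between (weights / n_year) and the
    weights obtained with a marginal sd of alpha / n_year."""
    scaled = [w / n_year for w in diag_spd_eq(alpha, lam, L, M)]
    equivalent = diag_spd_eq(alpha / n_year, lam, L, M)
    return max(abs(a - b) / abs(b) for a, b in zip(scaled, equivalent))

if __name__ == "__main__":
    # Hypothetical values for illustration only
    print(max_rel_diff(alpha=1.0, lam=0.5, L=2.6, M=20, n_year=25))
```

If that holds, the fixed factor behaves like a rescaling of the prior on `alpha_year` rather than a structurally different model, so a tighter prior on `alpha_year` might achieve the same effect.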
Here are the estimates with M = 50 basis functions:
The lengthscale `lambda_year` is quite large because the change in peak age over time is close to linear:
Full model code below.
Thanks in advance for any insight.
functions {
  // see https://github.com/avehtari/casestudies/blob/master/Birthdays
  // basis functions (exponentiated quadratic kernel)
  matrix PHI_EQ(int N, int M, real L, vector x) {
    matrix[N, M] A = rep_matrix(pi() / (2 * L) * (x + L), M);
    vector[M] B = linspaced_vector(M, 1, M);
    matrix[N, M] PHI = sin(diag_post_multiply(A, B)) / sqrt(L);
    for (m in 1:M) PHI[, m] = PHI[, m] - mean(PHI[, m]); // center each basis function to have mean 0
    return PHI;
  }
  // square root of the spectral density (exponentiated quadratic kernel)
  vector diagSPD_EQ(real alpha, real lambda, real L, int M) {
    vector[M] B = linspaced_vector(M, 1, M);
    return sqrt(alpha^2 * sqrt(2 * pi()) * lambda * exp(-0.5 * (lambda * pi() / (2 * L))^2 * B^2));
  }
  // log birth probability: Gaussian curve over age (vectorized version)
  vector log_birth_prob(vector age_id, vector a_peak, real log_h_peak, real sigma) {
    return log_h_peak - (age_id - a_peak)^2 / (2 * sigma^2);
  }
  // log birth probability: scalar version, used in generated quantities
  real log_birth_prob2(real age_id, real a_peak, real log_h_peak, real sigma) {
    return log_h_peak - (age_id - a_peak)^2 / (2 * sigma^2);
  }
}
data {
  int<lower=1> N;
  int<lower=1> N_year; // used to scale the GP
  int<lower=1> N_month;
  int<lower=1> N_age;
  //vector[N] n_birth;
  array[N] int n_birth;
  vector[N] n_pop;
  array[N] int month_id;
  array[N] int age_id;
  // Time points at which the GP is evaluated
  vector[N_month] x;
  // GP basis function settings
  real<lower=0> c_year; // Boundary scaling factor
  int M_year;           // Number of EQ basis functions
  // Hyperprior parameters
  array[2] real p_age_peak1;
  array[2] real p_log_h_peak1;
  array[2] real p_birth_prob_sigma1;
  array[2] real p_alpha_year;
  array[2] real p_lambda_year;
  real p_inv_sigma;
  int inference; // if 1, the likelihood is included
}
transformed data {
  // normalize data
  real x_mean = mean(x);
  real x_sd = sd(x);
  vector[N_month] xn = (x - x_mean) / x_sd;
  // compute boundary value
  real L_year = c_year * max(xn);
  // compute basis functions for f
  matrix[N_month, M_year] PHI_year = PHI_EQ(N_month, M_year, L_year, xn);
}
parameters {
  real<lower=0> inv_sigma;
  real<lower=0> age_peak1;
  real log_h_peak1;
  real<lower=0> birth_prob_sigma1;
  // GP
  real<lower=0> alpha_year;  // GP marginal standard deviation
  real<lower=0> lambda_year; // GP lengthscale (in units of the standardized time xn)
  vector[M_year] beta_year;  // Basis coefficients for the GP
}
transformed parameters {
  vector[M_year] diagSPD_year;
  vector[N_month] f_year;
  // compute spectral densities for f
  diagSPD_year = diagSPD_EQ(alpha_year, lambda_year, L_year, M_year);
  // compute f
  f_year = PHI_year * (diagSPD_year .* beta_year);
  // negative binomial overdispersion, passed as phi to neg_binomial_2_log
  real sigma = inv(inv_sigma);
}
model {
  inv_sigma ~ exponential(p_inv_sigma);
  // weak priors
  age_peak1 ~ normal(p_age_peak1[1], p_age_peak1[2]);
  log_h_peak1 ~ normal(p_log_h_peak1[1], p_log_h_peak1[2]);
  birth_prob_sigma1 ~ normal(p_birth_prob_sigma1[1], p_birth_prob_sigma1[2]);
  // GP: marginal standard deviation, lengthscale, and basis coefficients
  lambda_year ~ lognormal(p_lambda_year[1], p_lambda_year[2]);
  alpha_year ~ normal(p_alpha_year[1], p_alpha_year[2]);
  beta_year ~ normal(0, 1);
  if (inference == 1) {
    // dividing f_year by N_year might decrease the risk of divergences
    target += neg_binomial_2_log_lpmf(n_birth |
                log_birth_prob(to_vector(age_id),
                               age_peak1 * exp(f_year[month_id] / N_year),
                               log_h_peak1,
                               birth_prob_sigma1)
                + log(n_pop),
                sigma);
  }
}
generated quantities {
  array[N_month, N_age] real birth_prob;
  for (i in 1:N_month) {
    for (j in 1:N_age) {
      birth_prob[i, j] = exp(log_birth_prob2(j,
                                             age_peak1 * exp(f_year[i] / N_year),
                                             log_h_peak1,
                                             birth_prob_sigma1));
    }
  }
}


