Hilbert-space GP: Rhat inflation and age-specific bias when moving from yearly to monthly birth data

Aim:

I am modelling the number of births y_{i,m,j} by year i (2000–2024), month m and mother age j (15–50). The goal is to estimate the mean number of births by age without systematic bias. The model should also capture the shift of the age distribution over time (women having children at older ages). This relates to an earlier post, which presented another, simpler version of the model.

Model:

Since I encountered issues with the full model (described below), I first built a simpler baseline model and then plan to gradually add layers of complexity to reach the full model, in order to identify the source of the problem.

  • Model M0 (baseline): The model represents the yearly number of births by age y_{i,j} through a negative-binomial distribution:
y_{i,j}∼NegBin2(\mu_{i,j},\phi)

where \log \mu_{i,j} = \delta_0 + f(j) + \log P_{i,j}, with P_{i,j} the population and where f is modelled as a Hilbert-space GP over age j

  • Model M1: This is the same model as M0 but using monthly-level birth data:
    y_{i,m,j}∼NegBin2(\mu_{i,j},\phi)

Note, however, that the model assumes the same \mu for all months of a given year and age.

  • Model M2: I use the same likelihood as in M1 but add another GP g to represent the shift in mother age over time:
\log \mu_{i,j} = \delta_0 + f(j + g(i)) + \log P_{i,j}.
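The age shift in M2 enters as an input warping: the same age curve f is evaluated at shifted ages j + g(i). A minimal NumPy illustration of the idea (f and g here are made-up stand-ins, and linear interpolation stands in for evaluating the GP off the age grid):

```python
import numpy as np

ages = np.arange(15, 51, dtype=float)          # maternal ages j
years = np.arange(2000, 2025)                  # years i

f = np.exp(-0.5 * ((ages - 28) / 5.0) ** 2)    # stand-in age curve, peak at age 28
g = -0.08 * (years - 2000)                     # stand-in shift: about -2 "years" by 2024

# evaluate f at the warped inputs j + g(i); a negative g(i) shifts the curve
# toward older ages, since f(j + g) peaks where j = 28 - g
f_warped = np.array([np.interp(ages + gi, ages, f) for gi in g])

peak_age = ages[f_warped.argmax(axis=1)]       # modal age of the curve, by year
```

With this sign convention the modal maternal age drifts upward over the study period, which is the pattern the warped-input construction is meant to capture.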

Issues

Model M0 works perfectly: no divergences, \hat{R} < 1.02, and the model captures the observed number of births for each age j (i.e., no bias, defined as \text{bias}(j) = \sum_{i} (\mu_{i,j} - y_{i,j})):

For models M1 and M2, similar convergence issues appear. Several parameters show \hat{R} > 1.05 (1.10–1.20 in M1 and up to 3 in M2), mainly the \beta coefficients and occasionally the length-scale and marginal-standard-deviation parameters. In M1, chain-specific likelihoods and goodness-of-fit metrics are nearly identical despite the elevated \hat{R}. This suggests weak identifiability of the \beta parameters, resulting in a weakly constrained posterior and slow mixing across chains rather than distinct modes (is that right?).

In M2, convergence problems are more severe: some chains reach substantially lower likelihood values, indicating inadequate exploration and potentially stuck chains. Regarding bias, estimates are no longer centered at zero and appear to vary with maternal age. For model M1, the bias by age looks like this:

Tentative explanation

The elevated \hat{R} values in the full model M2 do not appear to be solely driven by the interaction between the two GP components, as initially suspected. Instead, the problem already emerges when switching from yearly to monthly data, even before introducing the second GP.

This is surprising because the modification from M0 to M1 is minimal and confined to the likelihood: splitting one yearly observation into 12 monthly observations. Why does this seemingly small change substantially alter the posterior geometry, leading to both slow mixing and biased estimates of the number of births by maternal age?

To investigate this, I have:

  • relaxed the homoscedasticity assumption by allowing one \phi per maternal age,
  • varied the number of sine basis functions,
  • imposed more informative priors on the length-scale,
  • and fixed the length-scale parameter;

none of these changes resolved the issue.

Many thanks in advance for your help.

Code below for Model M0/M1:

functions {
  //see https://github.com/avehtari/casestudies/blob/master/Birthdays
  // basis function (exponentiated quadratic kernel)
    matrix PHI_EQ(int N, int M, real L, vector x) {
      matrix[N,M] A = rep_matrix(pi()/(2*L) * (x+L), M);
      vector[M] B = linspaced_vector(M, 1, M);
      matrix[N,M] PHI = sin(diag_post_multiply(A, B))/sqrt(L);
      for (m in 1:M) PHI[,m] = PHI[,m] - mean(PHI[,m]); // scale to have mean 0
      return PHI;
    }
  // spectral density (exponentiated quadratic kernel)
  vector diagSPD_EQ(real alpha, real lambda, real L, int M) {
    vector[M] B = linspaced_vector(M, 1, M);
    return sqrt( alpha^2 * sqrt(2*pi()) * lambda * exp(-0.5*(lambda*pi()/(2*L))^2*B^2) );
  }
}

data {
  int<lower=1> N;
  int<lower=1> N_year1;
  int<lower=1> N_age;
  int<lower=1> N_sigma;
  
  //vector[N] n_birth;
  array[N] int n_birth;
  vector[N] n_pop;
  
  array[N] int year_id1;
  array[N] int age_id;
  array[N] int sigma_id;
  
  // Time points and corresponding locations
  vector[N_age] x1;
  
  // GP basis function settings
  real<lower=0> c_age;           // Boundary scaling factor
  int M_age;
  
  //prior 
  real p_inv_sigma;
  array[2] real p_delta0;
  array[2] real p_alpha;
  array[2] real p_rho;

  int inference;
}

transformed data {
  // normalize data
  real x1_mean = mean(x1);
  real x1_sd = sd(x1);
  vector[N_age] xn1 = (x1 - x1_mean)/x1_sd;
  
  // compute boundary value
  real L_age = c_age*max(xn1);
  
  // compute basis functions for f_year
  matrix[N_age,M_age] PHI_age = PHI_EQ(N_age, M_age, L_age, xn1);
}

parameters {
  vector <lower=0>[N_sigma] inv_sigma;
  real delta0;
  
  // GPs
  real<lower=0> alpha;
  real<lower=0> rho;
  vector[M_age] beta_age;
}
transformed parameters {
  vector[M_age] diagSPD_age;
  vector[N_age] f_age;
  
  // compute basis functions and spectral density for f_age
  diagSPD_age = diagSPD_EQ(alpha, rho, L_age, M_age);
  f_age = PHI_age * (diagSPD_age .* beta_age);

  //sigma
  vector[N_sigma] sigma = inv(inv_sigma);
}
model {
  //priors
  inv_sigma ~ exponential(p_inv_sigma);
  delta0 ~ normal(p_delta0[1], p_delta0[2]);
  
 //GP: variance and lengthscale
  alpha ~ normal(p_alpha[1], p_alpha[2]);
  rho ~ normal(p_rho[1], p_rho[2]);
  beta_age ~ std_normal();
  
  if(inference==1){
    target += neg_binomial_2_log_lpmf(n_birth | delta0 + f_age[age_id] + log(n_pop), sigma[sigma_id]);
    
  }
}

generated quantities{
  array[N_year1, N_age] real logit_birth_prob;
  array[N_year1, N_age] real birth_prob;
  array[N] int n_birth_pred;
  array[N] int n_birth_pois_pred;
  {
    for(i in 1:N_year1){
      for(j in 1:N_age){
         logit_birth_prob[i,j] =  f_age[j] + delta0;
         birth_prob[i,j] =  exp(f_age[j] + delta0);
      }
    }
    vector[N] log_mu;
    for(i in 1:N){
      log_mu[i] = f_age[age_id[i]] + delta0 + log(n_pop[i]);
    }
    n_birth_pred = neg_binomial_2_log_rng(log_mu, sigma[sigma_id]);
    n_birth_pois_pred = poisson_log_rng(log_mu);
  }

  vector[N_age] age_bias;
  for(i in 1:N_age){
    age_bias[i] = 0.0;
  }
  for(i in 1:N){
    age_bias[age_id[i]] = age_bias[age_id[i]] + (n_birth_pred[i]-n_birth[i]);
  }
}
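The basis-function construction above can be checked numerically: PHI diag(S) PHIᵀ should approximate the exponentiated-quadratic Gram matrix, where S is the EQ spectral density evaluated at the domain frequencies m·π/(2L) (diagSPD_EQ in the Stan code returns the square root of S, since it multiplies the weights rather than the Gram matrix). A NumPy sketch of the standard construction, without the mean-centering of the PHI columns:

```python
import numpy as np

alpha, rho = 1.0, 0.5
x = np.linspace(-1.0, 1.0, 50)                 # standardized inputs
L, M = 2.5, 40                                 # boundary and number of basis functions
m = np.arange(1, M + 1)

# sine basis functions phi_m(x) on [-L, L]
PHI = np.sin(np.pi * np.outer(x + L, m) / (2 * L)) / np.sqrt(L)
# EQ spectral density at frequencies m*pi/(2L)
S = alpha**2 * np.sqrt(2 * np.pi) * rho * np.exp(-0.5 * (rho * np.pi / (2 * L))**2 * m**2)

K_approx = PHI @ (S[:, None] * PHI.T)
K_exact = alpha**2 * np.exp(-0.5 * ((x[:, None] - x[None, :]) / rho) ** 2)
err = np.abs(K_approx - K_exact).max()         # should be tiny for this L, M, rho
```

With L comfortably larger than the data range and enough basis functions, the approximation error is negligible; it grows as rho shrinks relative to the spacing of the frequencies or as the data approach the boundary.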

Can you also provide data? Simulated data are fine.

Hi @avehtari, thank you, you can find below a reproducible example:

  • Data: csv file can be loaded from this link.
  • Stan file
  • R file (you might need to adjust the names/paths of both the Stan and csv files according to how you named them and where you saved them).

Stan model M0/M1.

(identical to the Stan listing in the first post, so not repeated here)


R file

library(cmdstanr)
library(dplyr)
library(ggplot2)
set_cmdstan_path("C:/TEMP/.cmdstan/cmdstan-2.36.0")
cmdstan_path()

#load data
#download the data file from: "https://drive.google.com/file/d/1f7svqwIgCzqYGZPAhCWvolWw4WCYBUnu/view?usp=sharing"
stan_df = readRDS("data/simulated_stan_df_link.RDS")


#choose the stratification level:
#by_year: one observation by year, Model M0, working
#!by_year: 12 observations by year (monthly), Model M1, Rhat>1.1 and some bias (see below)
by_year=FALSE

if(by_year){
  stan_df = stan_df %>% 
    group_by(age_id,year_id1) %>% 
    dplyr::summarise(n_birth = sum(n_birth),
                     n_pop = n_pop[1],.groups="drop")
}
stan_df

#Stan list--------------------------------------------------------------------
stan_data = list(N = dim(stan_df)[1],
                 N_year1 = length(unique(stan_df$year_id1)),
                 N_age = length(unique(stan_df$age_id)),
                 N_sigma =1,
                 
                 year_id1 = stan_df$year_id1,
                 age_id = stan_df$age_id,
                 sigma_id = rep(1,length(stan_df$age_id)),
                 
                 n_pop = stan_df$n_pop,
                 n_birth = stan_df$n_birth,
                 
                 x1 = 1:max(stan_df$age_id),
                 
                 M_age = 25, 
                 c_age = 5,
                 
                 p_rho = c(2,5),
                 p_alpha = c(3,1),
                 p_inv_sigma = 100,
                 p_delta0 = c(-5-log(12),2),
                 
                 inference = 1)
  
#Stan model-------------------------------------------------------------------
mod <- cmdstan_model("stan/mod_stan_discourse1.stan")
fit <- mod$sample(data = stan_data,
                  init=0, #this needs to be relaxed later on
                  chains = 4,
                  parallel_chains = 4,
                  iter_sampling = 200,
                  iter_warmup = 200, #tried 500 but didn't improve much
                  adapt_delta = 0.8,
                  refresh = 10,
                  seed = 1)

#high rhat for model M1 (up to 1.18)
fit$summary() %>% arrange(-rhat) %>% .[,1:10]

#bias: bias for model M1 in some ages
age_bias_df = fit$summary(variables = c("age_bias"), "mean",~quantile(.x, probs = c(0.025, 0.975))) %>%
  tidyr::extract(variable,into=c("variable","age_id"),
                 regex =paste0('(\\w.*)\\[',paste(rep("(.*)",1),collapse='\\,'),'\\]'), remove = T) %>%
  as_tibble() %>%
  dplyr::select(age_id,est=mean,lwb=`2.5%`,upb=`97.5%`) %>%
  dplyr::mutate(age_id=as.numeric(age_id))

age_bias_df %>% 
  ggplot(aes(x=age_id))+
  geom_ribbon(aes(ymin=lwb,ymax=upb),alpha=0.2,fill="darkred")+
  geom_line(aes(y=est),col="darkred")

Already with model M0 we see high dependencies in the posterior.

As M0 and M1 have the same parameters, in the case of a Poisson data model M0 and M1 would be equivalent, as the sum of two Poisson-distributed variables is also Poisson distributed. In the case of a negative binomial data model, M0 and M1 have different likelihoods, as the sum of two negative-binomially distributed variables is not negative-binomially distributed. Although the estimated overdispersion is not high, this change in the likelihood makes the posterior even more difficult. The estimated overdispersion is smaller in M1, which makes the likelihood more informative.
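The variance arithmetic behind this can be spelled out: with a common φ per observation, twelve monthly NegBin2(μ, φ) counts imply a much smaller yearly variance than a single yearly NegBin2(12μ, φ) observation, because the quadratic overdispersion term scales with the square of each observation's mean. A small check (the values are illustrative):

```python
# NegBin2(mu, phi) has variance mu + mu^2 / phi
mu, phi = 100.0, 10.0                  # illustrative monthly mean and dispersion

# yearly model M0: one NegBin2(12*mu, phi) observation per year/age
var_m0 = 12 * mu + (12 * mu) ** 2 / phi

# monthly model M1: twelve independent NegBin2(mu, phi) observations
var_m1 = 12 * (mu + mu ** 2 / phi)

# in the Poisson limit (phi -> infinity) both reduce to 12*mu,
# which is why M0 and M1 would be equivalent for a Poisson data model
```

So for the same φ the monthly model implies far less yearly overdispersion, i.e., a tighter, more informative likelihood, consistent with the smaller estimated overdispersion in M1.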

As hinted by the M0 posterior scatter plot and as discussed in the Birthdays case study, the intercept delta0 and beta_age[1] are not well identified separately (more about this in the soon-to-be-published Bayesian workflow book), causing a very high dependency, and due to the more informative likelihood in model M1, the dependency is even higher in the M1 posterior.

You can drop delta0 from the model by centering the data, as is commonly done in epidemiological modelling, replacing log(n_pop) with log(mean(n_birth)/mean(n_pop)*n_pop). After this, the sampling behavior is decent even with this short warmup.
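The replacement offset is just the original offset shifted by the log crude birth rate, so the intercept the model would otherwise need on top of it is pulled toward zero; a quick numeric check of the identity (with made-up data, assuming a true rate of 0.05):

```python
import numpy as np

rng = np.random.default_rng(0)
n_pop = rng.integers(10_000, 60_000, size=200).astype(float)
n_birth = rng.poisson(0.05 * n_pop).astype(float)   # simulated counts, rate 0.05

crude_rate = n_birth.mean() / n_pop.mean()
offset_old = np.log(n_pop)
offset_new = np.log(crude_rate * n_pop)

# identity: new offset = old offset + log(crude_rate),
# so log(crude_rate) is roughly the intercept delta0 being absorbed
absorbed = np.log(crude_rate)
```

Since delta0 no longer has to soak up the overall level, its strong dependency with beta_age[1] disappears along with it.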

I recommend dropping delta0, but to illustrate the stronger dependency in the M1 posterior, I ran sampling with iter_warmup = 400 and metric = "dense_e", which improves sampling a lot (and is fast enough, as there are not too many parameters to make the dense mass matrix slow). While with these short chains the rhats are not yet perfect, the scatter plot is informative and shows that the bivariate marginal posterior of delta0 and beta_age[1] is very narrow in the middle:

Other comments

You define

  //see https://github.com/avehtari/casestudies/blob/master/Birthdays
  // basis function (exponentiated quadratic kernel)
    matrix PHI_EQ(int N, int M, real L, vector x) {
      matrix[N,M] A = rep_matrix(pi()/(2*L) * (x+L), M);
      vector[M] B = linspaced_vector(M, 1, M);
      matrix[N,M] PHI = sin(diag_post_multiply(A, B))/sqrt(L);
      for (m in 1:M) PHI[,m] = PHI[,m] - mean(PHI[,m]); // scale to have mean 0
      return PHI;
    }

but it shouldn’t contain the line

      for (m in 1:M) PHI[,m] = PHI[,m] - mean(PHI[,m]); // scale to have mean 0

I did have such a line for a very short time (maybe days), but it’s wrong.

It seems that you otherwise also have a quite old version, and I recommend using the version in the git repo you mention:

functions {
  vector diagSPD_EQ(real alpha, real rho, real L, int M) {
    return alpha * sqrt(sqrt(2*pi()) * rho) * exp(-0.25*(rho*pi()/2/L)^2 * linspaced_vector(M, 1, M)^2);
  }
  matrix PHI_EQ(int N, int M, real L, vector x) {
    return sin(diag_post_multiply(rep_matrix(pi()/(2*L) * (x+L), M), linspaced_vector(M, 1, M)))/sqrt(L);
  }
}

In stan_data you have

                 c_age = 5,

this is very big; I have usually used 1.5. Such a big c will cause delta0 to be dependent also on many of the odd beta_age. I did my experiments with 1.5.
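The effect of a large boundary factor can be seen directly in the basis: with c = 5 the data occupy only the middle of [-L, L], so the odd-numbered sine basis functions are nearly constant over the observed range and become collinear with the intercept; with c = 1.5 they vary much more. A small check using the first basis function (the |mean|/sd ratio as a rough "flatness" measure):

```python
import numpy as np

x = np.linspace(-1.0, 1.0, 100)       # standardized ages

def flatness(c, m=1):
    """|mean| / sd of basis function phi_m over the data range; large = near-constant."""
    L = c * np.abs(x).max()
    phi = np.sin(np.pi * m * (x + L) / (2 * L)) / np.sqrt(L)
    return np.abs(phi.mean()) / phi.std()

r_big = flatness(5.0)                 # c = 5: phi_1 barely varies over the data
r_small = flatness(1.5)               # c = 1.5: phi_1 has real curvature
```

For c = 5 the first basis function is sin evaluated only near its peak, hence almost flat, which is exactly the intercept-like direction that tangles with delta0.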

You seem to have quite informative priors; the data are quite informative too, so much weaker priors would be just fine.

Thanks @avehtari for all the insights!

I adapted model M0 following your recommendations:

  1. Updated diagSPD_EQ and PHI_EQ (no scaling),
  2. Sampling with iter_warmup = 400 and metric = "dense_e",
  3. Decreased c_age to 1.5.

I obtained similar results for model M0, with only ~1% divergences (using adapt_delta = 0.8).

For M1, however, ~5% of iterations show divergences (using adapt_delta = 0.8; increasing adapt_delta reduces divergences but does not eliminate them completely). Did you also observe divergences? Note that my beta[1] has the opposite sign compared to yours.

Finally, I ran model M1 with delta0, replacing log(n_pop) with log(mean(n_birth)/mean(n_pop) * n_pop), which resulted in ~2% divergences.

Am I missing something? Many thanks again.

"Did you also observe divergences?"

Probably. It’s quite likely, looking at the shape of the posterior.

"Am I missing something?"

I don’t think so. Then you can increase the number of warmup and sampling iterations and increase adapt_delta a little bit.

You can get some speedup by fixing the length-scale. alpha and rho are not well identified separately anyway, so the result is not sensitive to the value at which the length-scale is fixed (many other basis-function approaches use only a magnitude parameter anyway, e.g. regularized thin-plate splines).