Gaussian processes in Stan: all transitions of one chain hit the maximum treedepth, but none in the other chains

Hi all,
The aim of my Stan model is to capture trends in count data over time, for example the number of weekly deaths in a given population. To do so, I use Gaussian processes (GPs) that capture three different trends:

  • Long-term trend (f1) capturing the decrease/increase of count data over years
  • Periodic trend (f2) capturing the seasonality, using a period of 1 year (50 weeks).
  • Short-term trend (f3) capturing the rapid changes in count data.
The lengthscales of the first two GPs are free parameters, while the lengthscale of the third GP is fixed at a very small value (3 weeks) so that it can capture rapid changes.
To check whether the model can identify the three GPs, I simulated count data with long-term and periodic trends but no short-term trend. The first goal was to check that the model correctly estimates the short-term trend at 0 when the data contain no short-term or temporary changes.
I use the basis-function approximation of Riutort-Mayol et al. to approximate the GPs, and I fit the count data with a Poisson likelihood.
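To make the approximation concrete, here is a minimal R sketch that mirrors the Stan helpers PHI_EQ() and diagSPD_EQ() from the model below: each GP is a fixed basis matrix times coefficients scaled by the square root of the kernel's spectral density. The names phi_eq() and spd_eq() and the alpha/lambda values in the example are illustrative assumptions, not part of the original code.

```r
# Sketch: R translations of the Stan functions PHI_EQ() and diagSPD_EQ()
# (basis-function GP approximation). phi_eq() and spd_eq() are hypothetical names.

# N x M matrix of sine basis functions on [-L, L],
# with each column centred to mean 0, as in the Stan code.
phi_eq <- function(x, M, L) {
  A <- outer(x + L, seq_len(M)) * pi / (2 * L)
  PHI <- sin(A) / sqrt(L)
  sweep(PHI, 2, colMeans(PHI))
}

# Square root of the exponentiated-quadratic spectral density at the
# frequencies omega_m = m * pi / (2 * L), m = 1..M.
spd_eq <- function(alpha, lambda, L, M) {
  omega <- seq_len(M) * pi / (2 * L)
  sqrt(alpha^2 * sqrt(2 * pi) * lambda * exp(-0.5 * (lambda * omega)^2))
}

# One prior draw of a long-term trend f1 on 300 weekly points, using c1 = 5 and
# M1 = 20 from the question and illustrative values alpha = 0.2, lambda = 1.
set.seed(1)
x  <- 1:300
xn <- (x - mean(x)) / sd(x)
L1 <- 5 * max(xn)
beta <- rnorm(20)
f1 <- phi_eq(xn, 20, L1) %*% (spd_eq(alpha = 0.2, lambda = 1, L = L1, M = 20) * beta)
```

Because the spectral density decays with frequency, a short lengthscale (such as the fixed 3-week one for f3) keeps high-frequency basis functions weighted, which is why that component needs many more basis functions than f1.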
The graph below shows that the model is able to identify the three GPs.

However, my concern is that in one chain all 300 (100%) post-warmup transitions hit the maximum treedepth of 10, while none of the transitions in the other three chains did. No divergences were reported in any of the four chains, and all Rhat < 1.05. I used 1000 warmup iterations and 300 sampling iterations. This does not happen systematically: in some runs, none of the four chains hits the maximum treedepth.
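A quick way to quantify this per chain is to tabulate the tree depth of every post-warmup transition. The sketch below assumes the tree depths are already in a plain numeric matrix (iterations x chains); max_treedepth_by_chain() is a hypothetical helper, not a CmdStanR function, and the toy input only mimics the pattern described above.

```r
# Sketch: per-chain count and proportion of transitions that saturated the
# tree depth. `treedepth` is assumed to be a numeric matrix with one row per
# post-warmup iteration and one column per chain.
max_treedepth_by_chain <- function(treedepth, max_depth = 10) {
  stopifnot(is.matrix(treedepth))
  hit <- treedepth >= max_depth
  data.frame(
    chain      = seq_len(ncol(treedepth)),
    n_max      = colSums(hit),
    prop_max   = colMeans(hit),
    mean_depth = colMeans(treedepth)
  )
}

# Made-up toy input: chain 1 always at depth 10, chains 2-4 at depth 5 or 6.
td <- cbind(rep(10, 300), rep(5, 300), rep(6, 300), rep(5, 300))
max_treedepth_by_chain(td)
```

In CmdStanR the per-iteration tree depth is stored as the `treedepth__` sampler diagnostic, so something like `matrix(fit2$sampler_diagnostics()[, , "treedepth__"], ncol = fit2$num_chains())` should give this matrix; treat that conversion as an assumption to verify against your CmdStanR and posterior versions.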

I know that this does not call the validity of the model into question, but I would like to understand why it happens in only one chain, since it reduces that chain's efficiency and typically makes it run 30-40% longer than the other chains. Because I plan to increase the complexity of the model, I want to make this toy model as efficient as possible.

I use R version 4.2.1 and run the stan model with CmdStanR package v0.5.3 and CmdStan 2.31.0.

You can find the data, code and Stan model below.
Many thanks in advance,
Anthony


Data
deaths1.RData (7.3 KB)

R code

library(cmdstanr)
library(tidyverse)
library(data.table)

#data
x = 1:300
x_mean = mean(x)
x_sd = sd(x)
xn = (x-mean(x))/x_sd

load("deaths1.RData")

pop=100000
time = deaths %>% dplyr::select(week.id) %>% unique()
data_list = list(
  N_deaths = dim(deaths)[1],
  N_x = length(time$week.id),
  N_age = uniqueN(deaths$age.id),
  age_id = deaths$age.id,
  deaths = deaths$deaths,
  week_id = deaths$week.id,

  pop = pop,

  x = time$week.id,

  c1 = 5,
  M1 = 20,
  J2 = 20,
  c3 = 3,
  M3 = 100,

  p_lambda1 = c(0,0.4),
  p_lambda2 = c(log(1)-0.2^2/2,0.2), #this prior might be too wide
  lambda3 = diff(xn)[1]*3,
  p_alpha1 = c(0,0.2),
  p_alpha3 = c(0,0.4),
  p_alpha2 = c(0,0.1),
  p_mu0_mean = c(-5),
  p_mu0_sd = c(0.5),

  inference = 1)

#Init functions
initfun <- function() {
  list(lambda1 = rlnorm(data_list$N_age, data_list$p_lambda1[1], data_list$p_lambda1[2]),
       lambda2 = rlnorm(data_list$N_age, data_list$p_lambda2[1], data_list$p_lambda2[2]),
       alpha1 = abs(rnorm(data_list$N_age, data_list$p_alpha1[1], data_list$p_alpha1[2])),
       alpha2 = abs(rnorm(data_list$N_age, data_list$p_alpha2[1], data_list$p_alpha2[2])),
       alpha3 = abs(rnorm(data_list$N_age, data_list$p_alpha3[1], data_list$p_alpha3[2])),
       mu0 = rnorm(data_list$N_age, data_list$p_mu0_mean, data_list$p_mu0_sd))
}

#Run Stan
gp_mod2 <- cmdstan_model("stan/gp_mod2.stan")
fit2 <- gp_mod2$sample(
  init = initfun,
  adapt_delta = 0.99,
  data = data_list,
  chains = 4,
  parallel_chains = 4,
  iter_warmup = 1000,
  iter_sampling = 300,
  refresh = 200)

fit2$summary(variables = c("f1", "f2", "f3"),
             ~quantile(.x, probs = c(0.025, 0.5, 0.975), na.rm = TRUE)) %>%
  as_tibble() %>%
  dplyr::select(variable, est = `50%`, lwb = `2.5%`, upb = `97.5%`) %>%
  tidyr::separate(col = variable, into = c("var", "age.id", "week.id"),
                  sep = "\\,|\\[|\\]", extra = "drop") %>%
  dplyr::mutate(age.id = as.numeric(age.id),
                week.id = as.numeric(week.id)) %>%
  left_join(deaths %>%
              tidyr::pivot_longer(cols = contains("f"), names_to = "var", values_to = "gp") %>%
              group_by(var) %>% dplyr::mutate(mean_gp = mean(gp)) %>% ungroup() %>%
              dplyr::mutate(gp = gp - mean_gp),
            by = c("var", "age.id", "week.id")) %>%
  ggplot(aes(x = week.id)) +
  geom_line(aes(y = est, col = var)) +
  geom_ribbon(aes(ymin = lwb, ymax = upb, col = var), alpha = 0.1) +
  geom_point(aes(y = gp), col = "black") +
  geom_line(aes(y = gp), col = "black") +
  facet_grid(var ~ age.id) +
  theme_bw()

Stan model

functions {
  //see https://github.com/avehtari/casestudies/blob/master/Birthdays
  // basis function (exponentiated quadratic kernel)
  matrix PHI_EQ(int N, int M, real L, vector x) {
    matrix[N,M] A = rep_matrix(pi()/(2*L) * (x+L), M);
    vector[M] B = linspaced_vector(M, 1, M);
    matrix[N,M] PHI = sin(diag_post_multiply(A, B))/sqrt(L);
    for (m in 1:M) PHI[,m] = PHI[,m] - mean(PHI[,m]); // centre each column to have mean 0
    return PHI;
  }
  // spectral density (exponentiated quadratic kernel)
  vector diagSPD_EQ(real alpha, real lambda, real L, int M) {
    vector[M] B = linspaced_vector(M, 1, M);
    return sqrt( alpha^2 * sqrt(2*pi()) * lambda * exp(-0.5*(lambda*pi()/(2*L))^2*B^2) );
  }
  
  vector diagSPD_periodic(real alpha, real lambda, int M) {
    real a = 1/lambda^2;
    array[M] int one_to_M;
    for (m in 1:M) one_to_M[m] = m;
    vector[M] q = sqrt(alpha^2 * 2 / exp(a) * to_vector(modified_bessel_first_kind(one_to_M, a)));
    return append_row(q,q);
  }
  
  matrix PHI_periodic(int N, int M, real w0, vector x) {
    matrix[N,M] mw0x = diag_post_multiply(rep_matrix(w0*x, M), linspaced_vector(M, 1, M));
    return append_col(cos(mw0x), sin(mw0x));
  }
}

// load data objects
data {
  //dimensions
  int N_deaths; //number of datapoints
  int N_x; //number of points at which GP is evaluated
  int N_age;
  
  //deaths data
  array[N_deaths] int age_id;
  array[N_deaths] int deaths;
  array[N_deaths] int week_id;
  
  //pop data
  array[N_age] int pop;
  
  //time
  vector[N_x] x; //locations at which GPs are evaluated
  
  //basis function GP
  real<lower=0> c1; // factor c to determine the boundary value L
  int M1; //number of basis functions
  real<lower=0> c3; // factor c to determine the boundary value L
  int M3; //number of basis functions
  int<lower=1> J2;   // number of cos and sin functions for periodic
  
  //Hyperprior parameters
  array[2] real p_lambda1;
  array[2] real p_lambda2;
  array[N_age] real lambda3;
  array[2] real p_alpha1;
  array[2] real p_alpha2;
  array[2] real p_alpha3;
  array[N_age] real p_mu0_mean;
  array[N_age] real p_mu0_sd;
  
  int inference;
}

transformed data {
  // normalize data
  real x_mean = mean(x);
  real x_sd = sd(x);
  vector[N_x] xn = (x - x_mean)/x_sd;
  
  // compute boundary value
  real L1 = c1*max(xn);
  real L3 = c3*max(xn);
  
  // compute basis functions for f
  matrix[N_x,M1] PHI1 = PHI_EQ(N_x, M1, L1, xn);
  matrix[N_x,M3] PHI3 = PHI_EQ(N_x, M3, L3, xn);
  real period_year = 50/x_sd; // period of 50 weeks on the normalized time scale
  matrix[N_x,2*J2] PHI2 = PHI_periodic(N_x, J2, 2*pi()/period_year, xn);
}


parameters {
  array[N_age] real mu0; //intercept mortality
  array[N_age] vector[M1] beta1; // basis function coefficients for f1
  array[N_age] real <lower=0> lambda1;      // lengthscale of f1
  array[N_age] real<lower=0> alpha1;        // magnitude of f1
  array[N_age] vector[M3] beta3; // basis function coefficients for f3
  array[N_age] real<lower=0> alpha3;        // magnitude of f3
  array[N_age] vector[2*J2] beta2; // basis function coefficients for f2
  array[N_age] real <lower=0> lambda2;      // lengthscale of f2
  array[N_age] real<lower=0> alpha2;        // magnitude of f2
}


transformed parameters {
  array[N_age] vector[M1] diagSPD1;
  array[N_age] vector[M3] diagSPD3;
  array[N_age] vector[2*J2] diagSPD2;
  array[N_age] vector[N_x] f1;
  array[N_age] vector[N_x] f3;
  array[N_age] vector[N_x] f2;
  array[N_age,N_x] real log_mu;
  
  for(i in 1:N_age){
    // compute spectral densities for f
    diagSPD1[i] = to_vector(diagSPD_EQ(alpha1[i], lambda1[i], L1, M1));
    diagSPD3[i] = to_vector(diagSPD_EQ(alpha3[i], lambda3[i], L3, M3));
    diagSPD2[i] = to_vector(diagSPD_periodic(alpha2[i], lambda2[i], J2));
    
    // compute f
    f1[i] = to_vector(PHI1 * (to_vector(diagSPD1[i]) .* to_vector(beta1[i])));
    f3[i] = to_vector(PHI3 * (to_vector(diagSPD3[i]) .* to_vector(beta3[i])));
    f2[i] = to_vector(PHI2 * (to_vector(diagSPD2[i]) .* to_vector(beta2[i])));
  }
  
  //log mortality rate
  for(i in 1:N_age){
    for(j in 1:N_x){
      log_mu[i,j] = mu0[i] + f1[i,j] + f3[i,j] + f2[i,j] + log(pop[i]);
    }
  }
}

model {
  for(i in 1:N_age){
    //intercept
    mu0[i] ~ normal(p_mu0_mean[i],p_mu0_sd[i]);
    // GP parameters
    beta1[i] ~ normal(0, 1);
    lambda1[i] ~ lognormal(p_lambda1[1],p_lambda1[2]);
    alpha1[i] ~ normal(p_alpha1[1],p_alpha1[2]); //alpha ~ normal(p_alpha[1], p_alpha[2]);
    beta3[i] ~ normal(0, 1);
    alpha3[i] ~ normal(p_alpha3[1],p_alpha3[2]); //alpha ~ normal(p_alpha[1], p_alpha[2]);
    beta2[i] ~ normal(0, 1);
    lambda2[i] ~ lognormal(p_lambda2[1],p_lambda2[2]);
    alpha2[i] ~ normal(p_alpha2[1],p_alpha2[2]); //alpha ~ normal(p_alpha[1], p_alpha[2]);
  }
  
  // likelihood
  if(inference==1){
    for(i in 1:N_deaths){
      target += poisson_log_lpmf(deaths[i] |  log_mu[age_id[i],week_id[i]]);
    }
  }
}

generated quantities {
  array[N_age,N_x] real mu;
  array[N_age,N_x] real mu_f1;
  array[N_age,N_x] real mu_f3;
  array[N_age,N_x] real mu_f2;
  for(i in 1:N_age){
    for(j in 1:N_x){
      mu_f1[i,j] = exp(mu0[i] + f1[i,j]);
      mu_f3[i,j] = exp(mu0[i] + f3[i,j]);
      mu_f2[i,j] = exp(mu0[i] + f2[i,j]);
      mu[i,j] = exp(mu0[i] + f1[i,j] + f3[i,j] + f2[i,j]);
    }
  }
}

Unfortunately, this can happen with random initialization in complex models. The cause is almost always an initialization somewhere in the tail that leads to a stiff Hamiltonian (very different scales in different dimensions), which in turn leads to small step sizes and the behavior you see.
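A back-of-the-envelope sketch of why stiffness translates into deep trees, assuming independent Gaussian directions and an identity (unadapted) metric: leapfrog is only stable for step sizes below twice the smallest scale, while a trajectory across the widest direction needs roughly half an oscillation period, about pi times the largest scale. NUTS trees contain 2^depth - 1 leapfrog steps, so the required depth grows with log2 of the ratio of scales. implied_treedepth() and eps_fraction are hypothetical illustrations, not Stan's actual adaptation logic.

```r
# Rough tree depth needed when the posterior scales are `sigma` and the
# step size is set to a fraction of the leapfrog stability limit 2 * min(sigma).
implied_treedepth <- function(sigma, eps_fraction = 0.5) {
  eps   <- eps_fraction * 2 * min(sigma)  # step size inside the stability limit
  steps <- pi * max(sigma) / eps          # about half an oscillation in the widest direction
  ceiling(log2(steps + 1))                # smallest depth d with 2^d - 1 >= steps
}

implied_treedepth(c(1, 1))      # 3
implied_treedepth(c(0.001, 1))  # 12, i.e. beyond the default max_treedepth of 10
```

Under these assumptions, scales differing by a factor of about 1000 are already enough to saturate a maximum tree depth of 10.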

The only solution we know of is to provide better initial values that are closer to the bulk of the posterior probability mass. The catch is that we still want them to be somewhat diffuse so that we can use $\widehat{R}$ as a convergence diagnostic.
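One possible variant of the original initfun along these lines is sketched below, assuming the `data_list` from the question. The original initfun does not set beta1, beta2 or beta3, so CmdStan initializes them randomly (uniform on (-2, 2) on the unconstrained scale); here they start near 0, so f1, f2 and f3 start close to flat curves, while hyperparameters are still drawn at random. make_initfun(), `shrink` and `beta_sd` are hypothetical, and whether these values are actually closer to the posterior bulk for this model would need checking.

```r
# Hypothetical init-function factory. `d` is a data list like data_list;
# `shrink` scales the prior SDs used for the hyperparameter draws (1 = as in
# the original initfun) and `beta_sd` sets the spread of the coefficient inits.
make_initfun <- function(d, shrink = 1, beta_sd = 0.1) {
  n <- d$N_age
  # N_age x k matrix, matching Stan's array[N_age] vector[k]
  coefs <- function(k) matrix(rnorm(n * k, 0, beta_sd), nrow = n, ncol = k)
  function() {
    list(
      mu0     = rnorm(n, d$p_mu0_mean, shrink * d$p_mu0_sd),
      lambda1 = rlnorm(n, d$p_lambda1[1], shrink * d$p_lambda1[2]),
      lambda2 = rlnorm(n, d$p_lambda2[1], shrink * d$p_lambda2[2]),
      alpha1  = abs(rnorm(n, d$p_alpha1[1], shrink * d$p_alpha1[2])),
      alpha2  = abs(rnorm(n, d$p_alpha2[1], shrink * d$p_alpha2[2])),
      alpha3  = abs(rnorm(n, d$p_alpha3[1], shrink * d$p_alpha3[2])),
      beta1   = coefs(d$M1),
      beta2   = coefs(2 * d$J2),
      beta3   = coefs(d$M3)
    )
  }
}
```

It would be passed as `init = make_initfun(data_list)` in the `$sample()` call; since each chain calls the function separately, the chains still start from different points.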


You may want to check whether the step size and mass matrix adapted during warmup are very different for that one chain. As @Bob_Carpenter mentioned, depending on the initial values and how long the warmup is (whether it is long enough to reach good step sizes and mass matrices in every chain, since each chain adapts independently), you may end up with very different sampler parameters in each chain. The rogue chain may be trying to "compensate" for a poor metric by taking more steps, and running into exactly this kind of trouble.
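A minimal sketch of that comparison, assuming a diagonal metric: put each chain's adapted step size next to the median across chains, and measure how far each chain's inverse-metric diagonal deviates (on the log scale) from the per-parameter median. compare_adaptation() is a hypothetical helper and the example values are made up.

```r
# `step_size`: numeric vector, one entry per chain.
# `inv_metric`: list with one numeric vector (the metric diagonal) per chain.
compare_adaptation <- function(step_size, inv_metric) {
  stopifnot(length(step_size) == length(inv_metric))
  M   <- do.call(cbind, inv_metric)   # parameters x chains
  ref <- apply(M, 1, median)          # per-parameter median across chains
  dev <- abs(log(M / ref))            # |log ratio| to the median, per entry
  data.frame(
    chain          = seq_along(step_size),
    step_size      = step_size,
    step_ratio     = step_size / median(step_size),
    max_metric_dev = apply(dev, 2, max)
  )
}

# Made-up example: chain 1 has a much smaller step size and a very
# different metric entry for the first parameter.
compare_adaptation(
  step_size  = c(0.002, 0.05, 0.06, 0.05),
  inv_metric = list(c(100, 1, 1), c(1, 1, 1), c(1, 1, 1), c(1, 1, 1))
)
```

With CmdStanR the inputs could presumably come from `fit2$metadata()$step_size_adaptation` and `fit2$inv_metric(matrix = FALSE)`; check that these accessors return per-chain values in your CmdStanR version.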

I recently released a suite of Markov chain Monte Carlo diagnostics, betanalpha/mcmc_diagnostics on GitHub ("Markov chain Monte Carlo general, and Hamiltonian Monte Carlo specific, diagnostics for Stan"), which includes functions for investigating differences in the behavior of individual chains, including the Hamiltonian Monte Carlo adaptation and the resulting exploration. The diagnostics are not implemented for CmdStanR, but the RStan implementation shouldn't be too hard to adapt. In any case, they should let you see fairly quickly what might be causing the discrepancies between chains.
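As a small, self-contained example of an HMC-specific exploration check that can be computed per chain, here is a sketch of E-BFMI using its usual definition, the sum of squared successive energy differences divided by the sum of squared deviations of the energy from its mean. ebfmi_by_chain() is a hypothetical helper, not a function from the repository above; low values point to a chain that struggles to explore the energy distribution.

```r
# Sketch: per-chain E-BFMI. `energy` is assumed to be a numeric matrix
# (iterations x chains) of the Hamiltonian energy at each transition.
ebfmi_by_chain <- function(energy) {
  apply(energy, 2, function(e) sum(diff(e)^2) / sum((e - mean(e))^2))
}

ebfmi_by_chain(cbind(c(1, 2, 1, 2), c(1, 2, 3, 4)))  # returns c(3, 0.6)
```

The energy matrix could be built from the `energy__` sampler diagnostic in the same way as the tree depth; treat that extraction step as an assumption to check in your CmdStanR version.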
