Hi all,
The aim of my Stan model is to capture trends in count data over time. It could for example be the number of weekly deaths in a given population. To do so, I used Gaussian processes (GPs) that capture three different trends:
- Long-term trend (f1) capturing the decrease/increase of count data over years
- Periodic trend (f2) capturing the seasonality, with a period of 1 year (taken as 50 weeks).
- Short-term trend (f3) capturing the rapid changes in count data.
The lengthscales of the first two GPs are free parameters, while the lengthscale of the third GP is fixed at a very small value (3 weeks) to capture rapid changes.
To check whether the model is able to identify the three GPs, I simulated count data with long-term and periodic trends but no short-term trend. The first goal was to check that the model estimates the short-term trend correctly (i.e., at 0) when there are no short-term/temporary changes in the data.
I use the basis-function (Hilbert-space) approximation of Riutort-Mayol et al. to approximate the GPs, and fit the count data with a Poisson likelihood.
The graph below shows that the model is able to identify the three GPs.
However, my concern is that all 300 (100%) post-warmup transitions of one chain hit the maximum treedepth of 10, while none of the transitions in the other three chains did. No divergences were reported in any of the four chains, and all Rhat values were < 1.05. I used 1000 warmup iterations and 300 sampling iterations. Note that this does not happen systematically: in some runs, none of the four chains hits the maximum treedepth.
I know that this does not call the validity of the model into question, but I would like to understand why it affects only one chain. It reduces that chain's efficiency, usually making its execution time 30–40% longer than that of the other chains. As I plan to increase the complexity of the model, it is important to me to make this toy model as efficient as possible.
I use R version 4.2.1 and run the stan model with CmdStanR package v0.5.3 and CmdStan 2.31.0.
You can find the data, code and Stan model below.
Many thanks in advance,
Anthony
Data
deaths1.RData (7.3 KB)
R code
library(cmdstanr)
library(tidyverse)
library(data.table)

# Data ----
# Provides the `deaths` data frame (columns used below: age.id, week.id,
# deaths, and the simulated true components f1, f2, f3).
load("deaths1.RData")

pop <- 100000

# Time points at which the GPs are evaluated. The Stan model uses `week_id` as
# a row index into the GP evaluated at `x`, so `x` must be exactly 1..N_x.
time <- deaths %>%
  dplyr::distinct(week.id) %>%
  dplyr::arrange(week.id)
x <- time$week.id
stopifnot(
  "week.id must be the consecutive integers 1..N_x" =
    isTRUE(all.equal(as.numeric(x), as.numeric(seq_along(x))))
)

# Standardise time exactly as Stan's `transformed data` does (mean/sd of `x`),
# so the fixed short-term lengthscale below is on the same scale as `xn` there.
x_mean <- mean(x)
x_sd <- sd(x)
xn <- (x - x_mean) / x_sd

n_age <- uniqueN(deaths$age.id)

data_list <- list(
  N_deaths = nrow(deaths),
  N_x = length(x),
  N_age = n_age,
  age_id = deaths$age.id,
  deaths = deaths$deaths,
  week_id = deaths$week.id,
  # Stan declares these as `array[N_age]`, so give one value per age group.
  pop = rep(pop, n_age),
  x = x,
  c1 = 5,
  M1 = 20,
  J2 = 20,
  c3 = 3,
  M3 = 100,
  p_lambda1 = c(0, 0.4),
  p_lambda2 = c(log(1) - 0.2^2 / 2, 0.2), # this prior might be too wide
  # 3 weeks on the standardised scale (same convention as `50 / x_sd` in Stan).
  lambda3 = rep(3 / x_sd, n_age),
  p_alpha1 = c(0, 0.2),
  p_alpha3 = c(0, 0.4),
  p_alpha2 = c(0, 0.1),
  p_mu0_mean = rep(-5, n_age),
  p_mu0_sd = rep(0.5, n_age),
  inference = 1
)
# Init functions ----
# Returns one list of initial values, drawn from the priors described in
# `data_list` (a global defined above). cmdstanr calls it once per chain, so
# each chain starts at a different, prior-plausible point.
initfun <- function() {
  n_age <- data_list$N_age
  list(
    lambda1 = rlnorm(n_age, data_list$p_lambda1[1], data_list$p_lambda1[2]),
    lambda2 = rlnorm(n_age, data_list$p_lambda2[1], data_list$p_lambda2[2]),
    # abs() folds the normal draws onto the <lower=0> support of alpha*.
    alpha1 = abs(rnorm(n_age, data_list$p_alpha1[1], data_list$p_alpha1[2])),
    alpha2 = abs(rnorm(n_age, data_list$p_alpha2[1], data_list$p_alpha2[2])),
    alpha3 = abs(rnorm(n_age, data_list$p_alpha3[1], data_list$p_alpha3[2])),
    mu0 = rnorm(n_age, data_list$p_mu0_mean, data_list$p_mu0_sd)
  )
}

# Run Stan ----
gp_mod2 <- cmdstan_model("stan/gp_mod2.stan")
fit2 <- gp_mod2$sample(
  data = data_list,
  init = initfun,
  chains = 4,
  parallel_chains = 4,
  iter_warmup = 1000,
  iter_sampling = 300,
  adapt_delta = 0.99,
  refresh = 200
)

# Posterior summaries of the GP components ----
gp_summary <- fit2$summary(
  variables = c("f1", "f2", "f3"),
  ~ quantile(.x, probs = c(0.025, 0.5, 0.975), na.rm = TRUE)
) %>%
  as_tibble() %>%
  dplyr::select(variable, est = `50%`, lwb = `2.5%`, upb = `97.5%`) %>%
  # "f1[1,5]" -> var = "f1", age.id = "1", week.id = "5". The closing "]"
  # produces an empty fourth piece, which is dropped explicitly.
  tidyr::separate(
    col = variable,
    into = c("var", "age.id", "week.id"),
    sep = ",|\\[|\\]",
    extra = "drop"
  ) %>%
  dplyr::mutate(
    age.id = as.numeric(age.id),
    week.id = as.numeric(week.id)
  )

# Simulated true components, centred within each (component, age group):
# the Stan basis for f1/f3 is column-centred, so the fitted components have
# mean 0 per age group and any level goes into mu0.
true_gp <- deaths %>%
  tidyr::pivot_longer(
    cols = matches("^f[0-9]+$"),
    names_to = "var",
    values_to = "gp"
  ) %>%
  group_by(var, age.id) %>%
  dplyr::mutate(gp = gp - mean(gp)) %>%
  ungroup()

gp_summary %>%
  left_join(true_gp, by = c("var", "age.id", "week.id")) %>%
  ggplot(aes(x = week.id)) +
  geom_line(aes(y = est, col = var)) +
  geom_ribbon(aes(ymin = lwb, ymax = upb, col = var), alpha = 0.1) +
  geom_point(aes(y = gp), col = "black") +
  geom_line(aes(y = gp), col = "black") +
  facet_grid(var ~ age.id) +
  theme_bw()
Stan model
functions {
  // Basis-function GP approximation helpers, adapted from
  // https://github.com/avehtari/casestudies/blob/master/Birthdays

  // Basis functions for the exponentiated-quadratic (EQ) kernel on [-L, L].
  // Returns an N x M matrix whose m-th column is
  // sin(m * pi / (2L) * (x + L)) / sqrt(L), shifted so that each column has
  // mean 0 over the N input points (so any f built from it has mean 0 over x).
  matrix PHI_EQ(int N, int M, real L, vector x) {
    matrix[N, M] A = rep_matrix(pi() / (2 * L) * (x + L), M);
    vector[M] B = linspaced_vector(M, 1, M);
    matrix[N, M] PHI = sin(diag_post_multiply(A, B)) / sqrt(L);
    for (m in 1:M) {
      PHI[, m] = PHI[, m] - mean(PHI[, m]);  // centre each column
    }
    return PHI;
  }

  // Square root of the EQ spectral density with magnitude alpha and
  // lengthscale lambda, evaluated at the M basis frequencies m * pi / (2L).
  // Used to scale N(0, 1) basis coefficients.
  vector diagSPD_EQ(real alpha, real lambda, real L, int M) {
    vector[M] B = linspaced_vector(M, 1, M);
    return sqrt( alpha^2 * sqrt(2*pi()) * lambda * exp(-0.5*(lambda*pi()/(2*L))^2*B^2) );
  }

  // Square-root weights for the periodic kernel approximation with M
  // harmonics: q_m = sqrt(alpha^2 * 2 * I_m(a) / exp(a)), a = 1 / lambda^2,
  // where I_m is the modified Bessel function of the first kind.
  // Returned twice (length 2M) to match the [cos | sin] columns of
  // PHI_periodic.
  // Evaluated on the log scale: for small lambda, both exp(a) and I_m(a)
  // overflow, and the direct ratio becomes NaN. The old integer-array
  // declaration syntax is no longer needed.
  vector diagSPD_periodic(real alpha, real lambda, int M) {
    real a = 1 / lambda^2;
    vector[M] q;
    for (m in 1:M) {
      q[m] = sqrt(alpha^2 * 2 * exp(log_modified_bessel_first_kind(m, a) - a));
    }
    return append_row(q, q);
  }

  // Fourier basis for the periodic kernel: an N x 2M matrix
  // [cos(m * w0 * x) | sin(m * w0 * x)] for m = 1..M, where w0 is the
  // fundamental angular frequency (2 * pi / period at the call site).
  matrix PHI_periodic(int N, int M, real w0, vector x) {
    matrix[N, M] mw0x = diag_post_multiply(rep_matrix(w0 * x, M),
                                           linspaced_vector(M, 1, M));
    return append_col(cos(mw0x), sin(mw0x));
  }
}
// load data objects
data {
  // dimensions
  int<lower=0> N_deaths;  // number of data points
  int<lower=1> N_x;       // number of points at which the GPs are evaluated
  int<lower=1> N_age;     // number of age groups
  // deaths data; age_id and week_id index rows/columns of log_mu, so they
  // are bounded here and bad indices fail at data load
  array[N_deaths] int<lower=1, upper=N_age> age_id;
  array[N_deaths] int<lower=0> deaths;                // observed counts
  array[N_deaths] int<lower=1, upper=N_x> week_id;
  // population per age group (enters log_mu as log(pop), so must be > 0)
  array[N_age] int<lower=1> pop;
  // time
  vector[N_x] x;  // locations at which the GPs are evaluated
  // basis-function GP settings
  real<lower=0> c1;        // boundary factor for f1: L1 = c1 * max|xn|
  int<lower=1> M1;         // number of basis functions for f1
  real<lower=0> c3;        // boundary factor for f3: L3 = c3 * max|xn|
  int<lower=1> M3;         // number of basis functions for f3
  int<lower=1> J2;         // number of cos and sin pairs for the periodic f2
  // hyperprior parameters (location, scale)
  array[2] real p_lambda1;
  array[2] real p_lambda2;
  array[N_age] real<lower=0> lambda3;  // fixed lengthscale of f3 (standardised scale)
  array[2] real p_alpha1;
  array[2] real p_alpha2;
  array[2] real p_alpha3;
  array[N_age] real p_mu0_mean;
  array[N_age] real<lower=0> p_mu0_sd;
  // 1 = include the likelihood; 0 = prior only
  int<lower=0, upper=1> inference;
}
transformed data {
  // standardise the time points (mean 0, sd 1)
  real x_mean = mean(x);
  real x_sd = sd(x);
  vector[N_x] xn = (x - x_mean) / x_sd;
  // Boundary L = c * max|xn| of the approximation interval [-L, L].
  // L3 previously used c1, which silently ignored the c3 supplied as data.
  real xn_absmax = fmax(max(xn), -min(xn));
  real L1 = c1 * xn_absmax;
  real L3 = c3 * xn_absmax;
  // basis functions depend only on data, so they are computed once
  matrix[N_x, M1] PHI1 = PHI_EQ(N_x, M1, L1, xn);
  matrix[N_x, M3] PHI3 = PHI_EQ(N_x, M3, L3, xn);
  // period of the seasonal GP: 50 weeks expressed on the standardised scale
  real period_year = 50 / x_sd;
  matrix[N_x, 2 * J2] PHI2 = PHI_periodic(N_x, J2, 2 * pi() / period_year, xn);
}
parameters {
  // One set of parameters per age group. Declaration order fixes the
  // unconstrained parameter layout and the output column order.
  array[N_age] real mu0; //intercept mortality (log scale)
  array[N_age] vector[M1] beta1; // basis function coefficients for f1 (long-term)
  array[N_age] real <lower=0> lambda1; // lengthscale of f1
  array[N_age] real<lower=0> alpha1; // magnitude of f1
  array[N_age] vector[M3] beta3; // basis function coefficients for f3 (short-term; lengthscale fixed by data lambda3)
  array[N_age] real<lower=0> alpha3; // magnitude of f3
  array[N_age] vector[2*J2] beta2; // basis function coefficients for f2 (periodic; cos and sin terms)
  array[N_age] real <lower=0> lambda2; // lengthscale of f2
  array[N_age] real<lower=0> alpha2; // magnitude of f2
}
transformed parameters {
  // square-root spectral weights scaling each basis coefficient
  array[N_age] vector[M1] diagSPD1;
  array[N_age] vector[M3] diagSPD3;
  array[N_age] vector[2 * J2] diagSPD2;
  // GP components evaluated at x, per age group
  array[N_age] vector[N_x] f1;  // long-term trend
  array[N_age] vector[N_x] f3;  // short-term trend
  array[N_age] vector[N_x] f2;  // periodic trend
  // log expected count: intercept + components + log(pop) offset
  array[N_age, N_x] real log_mu;
  for (a in 1:N_age) {
    diagSPD1[a] = diagSPD_EQ(alpha1[a], lambda1[a], L1, M1);
    diagSPD3[a] = diagSPD_EQ(alpha3[a], lambda3[a], L3, M3);
    diagSPD2[a] = diagSPD_periodic(alpha2[a], lambda2[a], J2);
    // f = PHI * (weights .* beta), with beta given standard-normal priors
    f1[a] = PHI1 * (diagSPD1[a] .* beta1[a]);
    f3[a] = PHI3 * (diagSPD3[a] .* beta3[a]);
    f2[a] = PHI2 * (diagSPD2[a] .* beta2[a]);
    log_mu[a] = to_array_1d(mu0[a] + f1[a] + f3[a] + f2[a] + log(pop[a]));
  }
}
model {
  // Intercepts: one prior per age group. The previous version ran
  // `mu0 ~ normal(p_mu0_mean[i], p_mu0_sd[i])` inside the age loop, which
  // placed every age group's prior on the whole mu0 array.
  mu0 ~ normal(p_mu0_mean, p_mu0_sd);
  // GP hyperparameters (the alphas are <lower=0>, so these are half-normals)
  lambda1 ~ lognormal(p_lambda1[1], p_lambda1[2]);
  lambda2 ~ lognormal(p_lambda2[1], p_lambda2[2]);
  alpha1 ~ normal(p_alpha1[1], p_alpha1[2]);
  alpha2 ~ normal(p_alpha2[1], p_alpha2[2]);
  alpha3 ~ normal(p_alpha3[1], p_alpha3[2]);
  // basis coefficients
  for (i in 1:N_age) {
    beta1[i] ~ std_normal();
    beta2[i] ~ std_normal();
    beta3[i] ~ std_normal();
  }
  // likelihood (skipped when inference == 0, i.e. prior-only runs)
  if (inference == 1) {
    vector[N_deaths] eta;
    for (n in 1:N_deaths) {
      eta[n] = log_mu[age_id[n], week_id[n]];
    }
    target += poisson_log_lpmf(deaths | eta);
  }
}
generated quantities {
  // Rates on the natural scale, without the log(pop) offset: the full
  // model (mu) and the intercept combined with each component on its own.
  array[N_age, N_x] real mu;
  array[N_age, N_x] real mu_f1;
  array[N_age, N_x] real mu_f3;
  array[N_age, N_x] real mu_f2;
  for (a in 1:N_age) {
    mu_f1[a] = to_array_1d(exp(mu0[a] + f1[a]));
    mu_f3[a] = to_array_1d(exp(mu0[a] + f3[a]));
    mu_f2[a] = to_array_1d(exp(mu0[a] + f2[a]));
    mu[a] = to_array_1d(exp(mu0[a] + f1[a] + f3[a] + f2[a]));
  }
}