Chains stuck in a local optimum: correlated Poisson distributions

To represent the correlation between two Poisson distributions, I developed a Stan model that adds the same random effect to the (log) mean parameter of both Poisson distributions. To model both negative and positive correlations, I assume that the random effect enters with the same sign in both means when the correlation is positive and with opposite signs when the correlation is negative:

data {
  int<lower=1> N;
  array[N] int y1;
  array[N] int y2;
  int inference;
}

parameters {
  real intercept1;
  real intercept2;
  real sigma_eta;
  array[N] real eta;                   // Random effects
}

transformed parameters {
  real sigma_eta1 = abs(sigma_eta);
  real sigma_eta2 = sigma_eta;
}

model {
  // Priors
  intercept1 ~ normal(2, 1);
  intercept2 ~ normal(2, 1);
  sigma_eta ~ normal(0, 4);
  for(i in 1:N){
    eta[i] ~ std_normal();
  }

  // Likelihood using poisson_log_lpmf
  if(inference==1){
    for (i in 1:N) {
      target += poisson_log_lpmf(y1[i] | intercept1 + eta[i]*sigma_eta1);
      target += poisson_log_lpmf(y2[i] | intercept2 + eta[i]*sigma_eta2);
    }
  }
}

generated quantities {
  real loglik = 0;
  for (i in 1:N) {
    loglik += poisson_log_lpmf(y1[i] | intercept1 + eta[i] * sigma_eta1);
    loglik += poisson_log_lpmf(y2[i] | intercept2 + eta[i] * sigma_eta2);
  }
}

I simulated two negatively correlated Poisson samples, y1 and y2, and tried to recover the negative correlation with the Stan model. When looking at the posterior distribution of the correlation parameter, I noticed a bimodal distribution:

library(cmdstanr)
library(dplyr)
library(ggplot2)

# Simulate data
corr12 = -1
df = data.frame(log_mu1 = log(10), log_mu2 = log(10), corr12 = corr12,
                mu_corr = rnorm(100, mean = 0, sd = abs(corr12)),
                n = 100) %>% 
  dplyr::mutate(mu1 = exp(log_mu1 + mu_corr),
                mu2 = if_else(corr12 < 0, exp(log_mu2 - mu_corr),
                              exp(log_mu2 + mu_corr + log(n)))) %>% 
  rowwise() %>% 
  dplyr::mutate(y1 = rpois(1, mu1),
                y2 = rpois(1, mu2))

# Plot of the two correlated Poisson distributions
df %>% 
  ggplot(aes(x=y1,y=y2))+
  geom_point()+
  xlim(c(0,100))+ylim(c(0,100))

data_list=list(N=dim(df)[1],
               y1=df$y1,
               y2=df$y2,
               inference=1)

# Stan model fit
mod3 <- cmdstan_model(paste0(code_root_path, "stan/test3_correlated_poisson.stan"))
fit <- mod3$sample(
  data = data_list,
  adapt_delta = 0.99,
  iter_warmup = 2000,
  chains = 4,
  parallel_chains = 4,
  show_messages = TRUE,
  refresh = 100 # print update every 100 iters
)

#High rhat
fit$summary() %>% 
  arrange(-rhat)
#The correlation parameter is not well estimated
fit$summary(variables = c("sigma_eta"), "mean",~quantile(.x, probs = c(0.025, 0.975)))

# Check posterior distributions between chains
d = fit$draws()[,,]
n_iter_per_chain = d[, 1, 1] %>% length()
d1 = as.data.frame(ftable(d[,, grepl("sigma_eta", dimnames(d)[[3]]) &
                             !grepl("sigma_eta2|sigma_eta1", dimnames(d)[[3]])]))
d2 = as.data.frame(ftable(d[,, grepl("loglik", dimnames(d)[[3]])]))
d3 = full_join(d1 %>% dplyr::select(iteration, chain, sigma_eta = Freq),
               d2 %>% dplyr::select(iteration, chain, loglik = Freq))
# Posterior distribution of the correlation parameter
d3 %>% 
  ggplot(aes(x = sigma_eta)) + 
  geom_histogram()
# Two chains with low log-likelihood
d3 %>%
  ggplot(aes(x = sigma_eta, y = loglik, col = chain)) +
  geom_point() + geom_line()
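
A quicker way to see the same per-chain pattern (a sketch using the bayesplot package, which is not loaded in the code above; the variable names match the model):

library(bayesplot)
draws <- fit$draws(variables = c("sigma_eta", "loglik"))
mcmc_dens_overlay(draws, pars = "sigma_eta") # one density per chain; bimodality shows up as separated chains
mcmc_trace(draws, pars = "loglik")           # the stuck chains sit at a visibly lower log-likelihood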

This happens because two chains are stuck in a local maximum. How can I avoid this? I thought that warmup iterations would prevent it, but even when increasing their number I get the same issue.

An alternative approach is to use copulas, which tie the marginal distributions together directly (instead of through the mean parameter). See Copulas for an example.
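
For intuition, here is a minimal sketch of the idea (a Gaussian copula with illustrative values, not taken from your post): correlated normals are pushed through pnorm() and the Poisson quantile function, so the dependence lives in the copula while the margins stay Poisson.

library(MASS)

rho <- -0.8                        # illustrative negative dependence
z <- MASS::mvrnorm(100, mu = c(0, 0),
                   Sigma = matrix(c(1, rho, rho, 1), 2, 2))
u <- pnorm(z)                      # correlated uniforms
y1 <- qpois(u[, 1], lambda = 10)   # Poisson(10) margin for y1
y2 <- qpois(u[, 2], lambda = 10)   # Poisson(10) margin for y2
cor(y1, y2)                        # negative, close to rho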

In addition to increasing the number of warmup iterations, have you tried changing other algorithm parameters? Since the loglik values are quite different, it seems like the sampler is really having trouble getting to a higher-probability region, which shouldn't be that difficult even for a moderately well-tuned sampler.

Thanks, this is a very good idea indeed. I will have a look.

No, I haven’t tried, because I wouldn’t know how to change these parameters or which ones I should adapt. Any suggestions?

I guess in the sampling statement you can add the keyword adapt_delta and change the default. The Stan documentation has more information on how warmup works and on other parameters that may help. In cmdstanpy you can find more of the sampler parameters here – I guess you are using cmdstanr, but I can’t find the equivalent page right away.
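
For example, a sketch of how these arguments could be passed in cmdstanr, reusing the mod3 and data_list objects from your post (the values are illustrative, not recommendations):

fit <- mod3$sample(
  data = data_list,
  chains = 4,
  parallel_chains = 4,
  adapt_delta = 0.995,  # take smaller steps during sampling
  max_treedepth = 12,   # allow longer NUTS trajectories
  iter_warmup = 4000,   # give adaptation more time
  iter_sampling = 2000,
  init = 0.5            # draw initial values from (-0.5, 0.5) instead of the default (-2, 2)
)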

I think that in this case you will be better off modelling the correlated latent residuals using lkj_corr_cholesky rather than trying to sign-flip an unconstrained sigma. I use my {mvgam} package below to show how this can be done, but of course you could set this up yourself without needing to use my package. This seems to work well for your simulated dataset, with good recovery of the true correlation and no major sampling issues to worry about.

library(mvgam)
library(dplyr)
library(ggplot2); theme_set(theme_bw())

#Simulate data
corr12 <- -1
df = data.frame(log_mu1 = log(10),
                log_mu2 = log(10),
                corr12 = corr12,
                mu_corr = rnorm(100,
                                mean = 0,
                                sd = abs(corr12)),
                n = 100) %>% 
  dplyr::mutate(mu1 = exp(log_mu1 + mu_corr),
                mu2 = if_else(corr12 < 0, 
                              exp(log_mu2 - mu_corr),
                            exp(log_mu2 + mu_corr + log(n)))) %>% 
  rowwise() %>% 
  dplyr::mutate(y1 = rpois(1, mu1),
                y2 = rpois(1, mu2))

# Plot of the two correlated Poisson distributions
df %>% 
  ggplot(aes(x=y1,y=y2))+
  geom_point()+
  xlim(c(0,100))+ylim(c(0,100))


# Convert to 'long' format
data.frame(y = c(df$y1, df$y2),
           variable = as.factor(c(rep('y1', 100),
                                  rep('y2', 100))),
           observation = c(1:100, 1:100)) -> dat

# Fit a model that uses Poisson observations but which
# allows the group-level latent residuals to be correlated;
# 4 parallel chains are run by default, same as in brms
mod <- mvgam(y ~ 1,
             # Correlated, zero-centred latent residuals
             trend_model = ZMVN(unit = observation, 
                                subgr = variable),
             priors = prior(normal(0, 4),
                            class = sigma),
             data = dat,
             family = poisson(),
             backend = 'cmdstanr')
#> Compiling Stan program using cmdstanr
#> 
#> Start sampling
#> Running MCMC with 4 parallel chains...
#> 
#> Chain 1 Iteration:   1 / 1000 [  0%]  (Warmup) 
#> Chain 2 Iteration:   1 / 1000 [  0%]  (Warmup) 
#> Chain 3 Iteration:   1 / 1000 [  0%]  (Warmup) 
#> Chain 4 Iteration:   1 / 1000 [  0%]  (Warmup) 
#> Chain 1 Iteration: 100 / 1000 [ 10%]  (Warmup) 
#> Chain 2 Iteration: 100 / 1000 [ 10%]  (Warmup) 
#> Chain 4 Iteration: 100 / 1000 [ 10%]  (Warmup) 
#> Chain 2 Iteration: 200 / 1000 [ 20%]  (Warmup) 
#> Chain 1 Iteration: 200 / 1000 [ 20%]  (Warmup) 
#> Chain 3 Iteration: 100 / 1000 [ 10%]  (Warmup) 
#> Chain 4 Iteration: 200 / 1000 [ 20%]  (Warmup) 
#> Chain 2 Iteration: 300 / 1000 [ 30%]  (Warmup) 
#> Chain 1 Iteration: 300 / 1000 [ 30%]  (Warmup) 
#> Chain 3 Iteration: 200 / 1000 [ 20%]  (Warmup) 
#> Chain 4 Iteration: 300 / 1000 [ 30%]  (Warmup) 
#> Chain 1 Iteration: 400 / 1000 [ 40%]  (Warmup) 
#> Chain 3 Iteration: 300 / 1000 [ 30%]  (Warmup) 
#> Chain 2 Iteration: 400 / 1000 [ 40%]  (Warmup) 
#> Chain 4 Iteration: 400 / 1000 [ 40%]  (Warmup) 
#> Chain 3 Iteration: 400 / 1000 [ 40%]  (Warmup) 
#> Chain 2 Iteration: 500 / 1000 [ 50%]  (Warmup) 
#> Chain 1 Iteration: 500 / 1000 [ 50%]  (Warmup) 
#> Chain 2 Iteration: 501 / 1000 [ 50%]  (Sampling) 
#> Chain 4 Iteration: 500 / 1000 [ 50%]  (Warmup) 
#> Chain 1 Iteration: 501 / 1000 [ 50%]  (Sampling) 
#> Chain 4 Iteration: 501 / 1000 [ 50%]  (Sampling) 
#> Chain 3 Iteration: 500 / 1000 [ 50%]  (Warmup) 
#> Chain 2 Iteration: 600 / 1000 [ 60%]  (Sampling) 
#> Chain 3 Iteration: 501 / 1000 [ 50%]  (Sampling) 
#> Chain 4 Iteration: 600 / 1000 [ 60%]  (Sampling) 
#> Chain 3 Iteration: 600 / 1000 [ 60%]  (Sampling) 
#> Chain 1 Iteration: 600 / 1000 [ 60%]  (Sampling) 
#> Chain 2 Iteration: 700 / 1000 [ 70%]  (Sampling) 
#> Chain 4 Iteration: 700 / 1000 [ 70%]  (Sampling) 
#> Chain 3 Iteration: 700 / 1000 [ 70%]  (Sampling) 
#> Chain 2 Iteration: 800 / 1000 [ 80%]  (Sampling) 
#> Chain 4 Iteration: 800 / 1000 [ 80%]  (Sampling) 
#> Chain 1 Iteration: 700 / 1000 [ 70%]  (Sampling) 
#> Chain 3 Iteration: 800 / 1000 [ 80%]  (Sampling) 
#> Chain 2 Iteration: 900 / 1000 [ 90%]  (Sampling) 
#> Chain 4 Iteration: 900 / 1000 [ 90%]  (Sampling) 
#> Chain 3 Iteration: 900 / 1000 [ 90%]  (Sampling) 
#> Chain 1 Iteration: 800 / 1000 [ 80%]  (Sampling) 
#> Chain 2 Iteration: 1000 / 1000 [100%]  (Sampling) 
#> Chain 2 finished in 5.7 seconds.
#> Chain 4 Iteration: 1000 / 1000 [100%]  (Sampling) 
#> Chain 4 finished in 5.6 seconds.
#> Chain 3 Iteration: 1000 / 1000 [100%]  (Sampling) 
#> Chain 3 finished in 5.9 seconds.
#> Chain 1 Iteration: 900 / 1000 [ 90%]  (Sampling) 
#> Chain 1 Iteration: 1000 / 1000 [100%]  (Sampling) 
#> Chain 1 finished in 7.5 seconds.
#> 
#> All 4 chains finished successfully.
#> Mean chain execution time: 6.2 seconds.
#> Total execution time: 7.8 seconds.

# The Stan code
stancode(mod)
#> // Stan model code generated by package mvgam
#> data {
#>   int<lower=0> total_obs; // total number of observations
#>   int<lower=0> n; // number of timepoints per series
#>   int<lower=0> n_series; // number of series
#>   int<lower=0> num_basis; // total number of basis coefficients
#>   matrix[total_obs, num_basis] X; // mgcv GAM design matrix
#>   array[n, n_series] int<lower=0> ytimes; // time-ordered matrix (which col in X belongs to each [time, series] observation?)
#>   int<lower=0> n_nonmissing; // number of nonmissing observations
#>   array[n_nonmissing] int<lower=0> flat_ys; // flattened nonmissing observations
#>   matrix[n_nonmissing, num_basis] flat_xs; // X values for nonmissing observations
#>   array[n_nonmissing] int<lower=0> obs_ind; // indices of nonmissing observations
#> }
#> transformed data {
#>   vector[n_series] trend_zeros = rep_vector(0.0, n_series);
#> }
#> parameters {
#>   // raw basis coefficients
#>   vector[num_basis] b_raw;
#>   
#>   // latent trend variance parameters
#>   vector<lower=0>[n_series] sigma;
#>   
#>   // correlated latent residuals
#>   array[n] vector[n_series] trend_raw;
#>   cholesky_factor_corr[n_series] L_Omega;
#> }
#> transformed parameters {
#>   matrix[n, n_series] trend;
#>   
#>   // LKJ form of covariance matrix
#>   matrix[n_series, n_series] L_Sigma;
#>   
#>   // basis coefficients
#>   vector[num_basis] b;
#>   b[1 : num_basis] = b_raw[1 : num_basis];
#>   
#>   // correlated residuals
#>   L_Sigma = diag_pre_multiply(sigma, L_Omega);
#>   for (i in 1 : n) {
#>     trend[i, 1 : n_series] = to_row_vector(trend_raw[i]);
#>   }
#> }
#> model {
#>   // prior for (Intercept)...
#>   b_raw[1] ~ student_t(3, 2.3, 2.5);
#>   
#>   // priors for latent trend variance parameters
#>   sigma ~ normal(0, 4);
#>   
#>   // residual error correlations
#>   L_Omega ~ lkj_corr_cholesky(2);
#>   for (i in 1 : n) {
#>     trend_raw[i] ~ multi_normal_cholesky(trend_zeros, L_Sigma);
#>   }
#>   {
#>     // likelihood functions
#>     vector[n_nonmissing] flat_trends;
#>     flat_trends = to_vector(trend)[obs_ind];
#>     flat_ys ~ poisson_log_glm(append_col(flat_xs, flat_trends), 0.0,
#>                               append_row(b, 1.0));
#>   }
#> }
#> generated quantities {
#>   vector[total_obs] eta;
#>   matrix[n, n_series] mus;
#>   vector[n_series] tau;
#>   array[n, n_series] int ypred;
#>   for (s in 1 : n_series) {
#>     tau[s] = pow(sigma[s], -2.0);
#>   }
#>   
#>   // computed error covariance matrix
#>   cov_matrix[n_series] Sigma = multiply_lower_tri_self_transpose(L_Sigma);
#>   
#>   // posterior predictions
#>   eta = X * b;
#>   for (s in 1 : n_series) {
#>     mus[1 : n, s] = eta[ytimes[1 : n, s]] + trend[1 : n, s];
#>     ypred[1 : n, s] = poisson_log_rng(mus[1 : n, s]);
#>   }
#> }

# Diagnostics
summary(mod)
#> GAM formula:
#> y ~ 1
#> 
#> Family:
#> poisson
#> 
#> Link function:
#> log
#> 
#> Trend model:
#> ZMVN(unit = observation, subgr = variable)
#> 
#> 
#> N series:
#> 2 
#> 
#> N timepoints:
#> 100 
#> 
#> Status:
#> Fitted using Stan 
#> 4 chains, each with iter = 1000; warmup = 500; thin = 1 
#> Total post-warmup draws = 2000
#> 
#> 
#> GAM coefficient (beta) estimates:
#>             2.5% 50% 97.5% Rhat n_eff
#> (Intercept)  2.2 2.3   2.3    1  1372
#> 
#> Stan MCMC diagnostics:
#> n_eff / iter looks reasonable for all parameters
#> Rhat looks reasonable for all parameters
#> 1 of 2000 iterations ended with a divergence (0.05%)
#>  *Try running with larger adapt_delta to remove the divergences
#> 0 of 2000 iterations saturated the maximum tree depth of 10 (0%)
#> Chain 4: E-FMI = 0.1947
#>  *E-FMI below 0.2 indicates you may need to reparameterize your model
#> 
#> Samples were drawn using NUTS(diag_e) at Mon Dec 02 9:10:33 AM 2024.
#> For each parameter, n_eff is a crude measure of effective sample size,
#> and Rhat is the potential scale reduction factor on split MCMC chains
#> (at convergence, Rhat = 1)
mcmc_plot(mod, 
          type = 'rhat_hist')
#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

mcmc_plot(mod, 
          type = 'neff_hist')
#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

mcmc_plot(mod,
          type = 'trace')


# Implied error Variance-Covariance matrix
mcmc_plot(mod, 
          variable = 'Sigma', 
          regex = TRUE, 
          type = 'hist')
#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.


# Posterior mean error correlation matrix
Sigmas <- as.matrix(mod, 
                    variable = 'Sigma', 
                    regex = TRUE)
Reduce( 
  '+', 
  lapply(1:NROW(Sigmas), function(x){
    cov2cor(matrix(Sigmas[x,], nrow = 2, ncol = 2))
  })) / NROW(Sigmas)
#>            [,1]       [,2]
#> [1,]  1.0000000 -0.9646524
#> [2,] -0.9646524  1.0000000

Created on 2024-12-02 with reprex v2.0.2
