I think that in this case you will be better off modelling the correlated latent residuals with an `lkj_corr_cholesky` prior rather than trying to sign-flip an unconstrained `sigma`. I use my {mvgam} package below to show how this can be done, but of course you could set this up yourself without needing to use my package. This seems to work well for your simulated dataset, with good recovery of the true correlation and no major sampling issues to worry about.
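The key intuition is that the sign of the association belongs in the correlation matrix, while the standard deviations stay strictly positive. Here is a minimal sketch of that idea in base R (the SDs, correlation and means are made-up illustration values, and `MASS::mvrnorm` is just a convenient simulator):

# Sign of the dependence lives in the correlation matrix Rho,
# not in the strictly positive SDs
sds <- c(0.4, 0.4)
Rho <- matrix(c(1, -0.9, -0.9, 1), 2, 2)
Sig <- diag(sds) %*% Rho %*% diag(sds)
resids <- MASS::mvrnorm(1000, mu = c(0, 0), Sigma = Sig)
y1 <- rpois(1000, exp(log(10) + resids[, 1]))
y2 <- rpois(1000, exp(log(10) + resids[, 2]))
cor(y1, y2) # strongly negative, no sign-flipped sigma needed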
library(mvgam)
library(dplyr)
library(ggplot2); theme_set(theme_bw())
# Simulate data
corr12 <- -1
df <- data.frame(log_mu1 = log(10),
                 log_mu2 = log(10),
                 corr12 = corr12,
                 mu_corr = rnorm(100,
                                 mean = 0,
                                 sd = abs(corr12)),
                 n = 100) %>%
  dplyr::mutate(mu1 = exp(log_mu1 + mu_corr),
                mu2 = if_else(corr12 < 0,
                              exp(log_mu2 - mu_corr),
                              exp(log_mu2 + mu_corr + log(n)))) %>%
  rowwise() %>%
  dplyr::mutate(y1 = rpois(1, mu1),
                y2 = rpois(1, mu2))
# Plot of the two correlated Poisson distributions
df %>%
  ggplot(aes(x = y1, y = y2)) +
  geom_point() +
  xlim(c(0, 100)) +
  ylim(c(0, 100))
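A quick sanity check on the simulated counts (the exact value will vary since no seed is set):

# Empirical correlation; should be strongly negative given corr12 = -1
cor(df$y1, df$y2)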
# Convert to 'long' format
dat <- data.frame(y = c(df$y1, df$y2),
                  variable = as.factor(c(rep('y1', 100),
                                         rep('y2', 100))),
                  observation = c(1:100, 1:100))
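Each of the 100 original rows now contributes one row per series, linked by the observation index:

# Inspect the long format; 'observation' ties together the two counts
# that came from the same original row
head(dat)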
# Fit a model that uses Poisson observations but which
# allows the group-level latent residuals to be correlated;
# 4 parallel chains are run by default, same as in brms
mod <- mvgam(y ~ 1,
             # Correlated, zero-centred latent residuals
             trend_model = ZMVN(unit = observation,
                                subgr = variable),
             priors = prior(normal(0, 4),
                            class = sigma),
             data = dat,
             family = poisson(),
             backend = 'cmdstanr')
#> Compiling Stan program using cmdstanr
#>
#> Start sampling
#> Running MCMC with 4 parallel chains...
#>
#> Chain 2 finished in 5.7 seconds.
#> Chain 4 finished in 5.6 seconds.
#> Chain 3 finished in 5.9 seconds.
#> Chain 1 finished in 7.5 seconds.
#>
#> All 4 chains finished successfully.
#> Mean chain execution time: 6.2 seconds.
#> Total execution time: 7.8 seconds.
# The Stan code
stancode(mod)
#> // Stan model code generated by package mvgam
#> data {
#>   int<lower=0> total_obs; // total number of observations
#>   int<lower=0> n; // number of timepoints per series
#>   int<lower=0> n_series; // number of series
#>   int<lower=0> num_basis; // total number of basis coefficients
#>   matrix[total_obs, num_basis] X; // mgcv GAM design matrix
#>   array[n, n_series] int<lower=0> ytimes; // time-ordered matrix (which col in X belongs to each [time, series] observation?)
#>   int<lower=0> n_nonmissing; // number of nonmissing observations
#>   array[n_nonmissing] int<lower=0> flat_ys; // flattened nonmissing observations
#>   matrix[n_nonmissing, num_basis] flat_xs; // X values for nonmissing observations
#>   array[n_nonmissing] int<lower=0> obs_ind; // indices of nonmissing observations
#> }
#> transformed data {
#>   vector[n_series] trend_zeros = rep_vector(0.0, n_series);
#> }
#> parameters {
#>   // raw basis coefficients
#>   vector[num_basis] b_raw;
#>
#>   // latent trend variance parameters
#>   vector<lower=0>[n_series] sigma;
#>
#>   // correlated latent residuals
#>   array[n] vector[n_series] trend_raw;
#>   cholesky_factor_corr[n_series] L_Omega;
#> }
#> transformed parameters {
#>   matrix[n, n_series] trend;
#>
#>   // LKJ form of covariance matrix
#>   matrix[n_series, n_series] L_Sigma;
#>
#>   // basis coefficients
#>   vector[num_basis] b;
#>   b[1 : num_basis] = b_raw[1 : num_basis];
#>
#>   // correlated residuals
#>   L_Sigma = diag_pre_multiply(sigma, L_Omega);
#>   for (i in 1 : n) {
#>     trend[i, 1 : n_series] = to_row_vector(trend_raw[i]);
#>   }
#> }
#> model {
#>   // prior for (Intercept)...
#>   b_raw[1] ~ student_t(3, 2.3, 2.5);
#>
#>   // priors for latent trend variance parameters
#>   sigma ~ normal(0, 4);
#>
#>   // residual error correlations
#>   L_Omega ~ lkj_corr_cholesky(2);
#>   for (i in 1 : n) {
#>     trend_raw[i] ~ multi_normal_cholesky(trend_zeros, L_Sigma);
#>   }
#>   {
#>     // likelihood functions
#>     vector[n_nonmissing] flat_trends;
#>     flat_trends = to_vector(trend)[obs_ind];
#>     flat_ys ~ poisson_log_glm(append_col(flat_xs, flat_trends), 0.0,
#>                               append_row(b, 1.0));
#>   }
#> }
#> generated quantities {
#>   vector[total_obs] eta;
#>   matrix[n, n_series] mus;
#>   vector[n_series] tau;
#>   array[n, n_series] int ypred;
#>   for (s in 1 : n_series) {
#>     tau[s] = pow(sigma[s], -2.0);
#>   }
#>
#>   // computed error covariance matrix
#>   cov_matrix[n_series] Sigma = multiply_lower_tri_self_transpose(L_Sigma);
#>
#>   // posterior predictions
#>   eta = X * b;
#>   for (s in 1 : n_series) {
#>     mus[1 : n, s] = eta[ytimes[1 : n, s]] + trend[1 : n, s];
#>     ypred[1 : n, s] = poisson_log_rng(mus[1 : n, s]);
#>   }
#> }
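If the two Stan helper functions above are unfamiliar, this is what they compute, written out in R with made-up values:

# diag_pre_multiply(sigma, L_Omega) scales the Cholesky factor of the
# correlation matrix into the Cholesky factor of the covariance matrix;
# multiply_lower_tri_self_transpose(L_Sigma) then rebuilds Sigma itself
sigma_ex <- c(0.5, 0.7)
Omega_ex <- matrix(c(1, -0.8, -0.8, 1), 2, 2)
L_Omega_ex <- t(chol(Omega_ex)) # Stan's Cholesky factors are lower-triangular
L_Sigma_ex <- diag(sigma_ex) %*% L_Omega_ex
L_Sigma_ex %*% t(L_Sigma_ex) # equals diag(sigma) %*% Omega %*% diag(sigma)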
# Diagnostics
summary(mod)
#> GAM formula:
#> y ~ 1
#>
#> Family:
#> poisson
#>
#> Link function:
#> log
#>
#> Trend model:
#> ZMVN(unit = observation, subgr = variable)
#>
#>
#> N series:
#> 2
#>
#> N timepoints:
#> 100
#>
#> Status:
#> Fitted using Stan
#> 4 chains, each with iter = 1000; warmup = 500; thin = 1
#> Total post-warmup draws = 2000
#>
#>
#> GAM coefficient (beta) estimates:
#> 2.5% 50% 97.5% Rhat n_eff
#> (Intercept) 2.2 2.3 2.3 1 1372
#>
#> Stan MCMC diagnostics:
#> n_eff / iter looks reasonable for all parameters
#> Rhat looks reasonable for all parameters
#> 1 of 2000 iterations ended with a divergence (0.05%)
#> *Try running with larger adapt_delta to remove the divergences
#> 0 of 2000 iterations saturated the maximum tree depth of 10 (0%)
#> Chain 4: E-FMI = 0.1947
#> *E-FMI below 0.2 indicates you may need to reparameterize your model
#>
#> Samples were drawn using NUTS(diag_e) at Mon Dec 02 9:10:33 AM 2024.
#> For each parameter, n_eff is a crude measure of effective sample size,
#> and Rhat is the potential scale reduction factor on split MCMC chains
#> (at convergence, Rhat = 1)
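The single divergence and the chain 4 E-FMI warning are mild here, but a quick way to double-check would be a refit with stricter sampler settings, along these lines (I'm assuming your installed mvgam version exposes `adapt_delta`; see `?mvgam`):

# Hypothetical refit with a larger adapt_delta to remove the divergence
mod2 <- mvgam(y ~ 1,
              trend_model = ZMVN(unit = observation,
                                 subgr = variable),
              priors = prior(normal(0, 4),
                             class = sigma),
              data = dat,
              family = poisson(),
              backend = 'cmdstanr',
              adapt_delta = 0.99)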
mcmc_plot(mod,
          type = 'rhat_hist')
#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
mcmc_plot(mod,
          type = 'neff_hist')
#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
mcmc_plot(mod,
          type = 'trace')
# Implied error variance-covariance matrix
mcmc_plot(mod,
          variable = 'Sigma',
          regex = TRUE,
          type = 'hist')
#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# Posterior mean error correlation matrix
Sigmas <- as.matrix(mod,
                    variable = 'Sigma',
                    regex = TRUE)
Reduce('+',
       lapply(1:NROW(Sigmas), function(x){
         cov2cor(matrix(Sigmas[x, ], nrow = 2, ncol = 2))
       })) / NROW(Sigmas)
#> [,1] [,2]
#> [1,] 1.0000000 -0.9646524
#> [2,] -0.9646524 1.0000000
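Rather than just the posterior mean, you could also summarise the full posterior of the correlation:

# Posterior median and 95% interval for the residual correlation
corr_draws <- apply(Sigmas, 1, function(x){
  cov2cor(matrix(x, nrow = 2, ncol = 2))[1, 2]
})
quantile(corr_draws, probs = c(0.025, 0.5, 0.975))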
Created on 2024-12-02 with reprex v2.0.2