You can also use either {mvgam} or {brms} to autogenerate Stan code for regression models with latent, autocorrelated residual processes. Below is a reprex showing how to do this in {mvgam}; the workflow in {brms} is very similar (a sketch of the equivalent {brms} call is included at the end). Note that the generated code is slightly more complex than you would strictly need, because it is built to handle a wide variety of predictor effects, but it should give you the general idea. It is generally recommended to put a reasonable prior on the AR coefficients, and perhaps to restrict them to the stationary region (for an AR(1), |phi| < 1), which is what both {mvgam} and {brms} do by default. The second model below also shows how to use a noncentred parameterisation for the latent AR(1) process, which often leads to better mixing and more effective samples per iteration for this type of model.
# Load the mvgam library
library(mvgam)
#> Welcome to mvgam. Please cite as: Clark, NJ, and Wells, K. 2022. Dynamic Generalized Additive Models (DGAMs) for forecasting discrete ecological time series. Methods in Ecology and Evolution, 2022, https://doi.org/10.1111/2041-210X.13974
# Simulate integer-valued observations over a latent, real-valued AR(1) process
set.seed(0)
T <- 100
phi <- 0.75
sigma <- 0.5
alpha <- 1.25
loglambda <- numeric(T)
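# Initialise the first value from Normal(0, sigma), matching the prior
# that mvgam places on trend[1] in the generated Stan code below (the
# stationary sd of an AR(1) would instead be sigma / sqrt(1 - phi^2))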
loglambda[1] <- rnorm(1, mean = 0, sd = sigma)
for (t in 2:T) {
  loglambda[t] <- rnorm(1, mean = phi * loglambda[t - 1],
                        sd = sigma)
}
# Plot the real-valued latent AR(1) process
plot(loglambda, type = 'l', xlab = 'Time')
# Take Poisson observations (using a log link function) and plot
y <- rpois(T, lambda = exp(alpha + loglambda))
plot(y, type = 'l', xlab = 'Time')
# Gather data into a data.frame
dat <- data.frame(y = y,
                  time = 1:T)
# Fit a Poisson AR(1) model using the standard (centred) parameterisation
mod <- mvgam(y ~ 1,
             trend_model = AR(p = 1),
             family = poisson(),
             data = dat)
#> Your model may benefit from using "noncentred = TRUE"
#> Compiling Stan program using cmdstanr
#>
#> Start sampling
#> Running MCMC with 4 parallel chains...
#>
#> All 4 chains finished successfully.
#> Mean chain execution time: 0.7 seconds.
#> Total execution time: 1.3 seconds.
# Inspect the auto-generated Stan code
stancode(mod)
#> // Stan model code generated by package mvgam
#> data {
#>   int<lower=0> total_obs; // total number of observations
#>   int<lower=0> n; // number of timepoints per series
#>   int<lower=0> n_series; // number of series
#>   int<lower=0> num_basis; // total number of basis coefficients
#>   matrix[total_obs, num_basis] X; // mgcv GAM design matrix
#>   array[n, n_series] int<lower=0> ytimes; // time-ordered matrix (which col in X belongs to each [time, series] observation?)
#>   int<lower=0> n_nonmissing; // number of nonmissing observations
#>   array[n_nonmissing] int<lower=0> flat_ys; // flattened nonmissing observations
#>   matrix[n_nonmissing, num_basis] flat_xs; // X values for nonmissing observations
#>   array[n_nonmissing] int<lower=0> obs_ind; // indices of nonmissing observations
#> }
#> parameters {
#>   // raw basis coefficients
#>   vector[num_basis] b_raw;
#>
#>   // latent trend AR1 terms
#>   vector<lower=-1, upper=1>[n_series] ar1;
#>
#>   // latent trend variance parameters
#>   vector<lower=0>[n_series] sigma;
#>
#>   // latent trends
#>   matrix[n, n_series] trend;
#> }
#> transformed parameters {
#>   // basis coefficients
#>   vector[num_basis] b;
#>   b[1 : num_basis] = b_raw[1 : num_basis];
#> }
#> model {
#>   // prior for (Intercept)...
#>   b_raw[1] ~ student_t(3, 1.1, 2.5);
#>
#>   // priors for AR parameters
#>   ar1 ~ std_normal();
#>
#>   // priors for latent trend variance parameters
#>   sigma ~ student_t(3, 0, 2.5);
#>
#>   // trend estimates
#>   trend[1, 1 : n_series] ~ normal(0, sigma);
#>   for (s in 1 : n_series) {
#>     trend[2 : n, s] ~ normal(ar1[s] * trend[1 : (n - 1), s], sigma[s]);
#>   }
#>   {
#>     // likelihood functions
#>     vector[n_nonmissing] flat_trends;
#>     flat_trends = to_vector(trend)[obs_ind];
#>     flat_ys ~ poisson_log_glm(append_col(flat_xs, flat_trends), 0.0,
#>                               append_row(b, 1.0));
#>   }
#> }
#> generated quantities {
#>   vector[total_obs] eta;
#>   matrix[n, n_series] mus;
#>   vector[n_series] tau;
#>   array[n, n_series] int ypred;
#>   for (s in 1 : n_series) {
#>     tau[s] = pow(sigma[s], -2.0);
#>   }
#>
#>   // posterior predictions
#>   eta = X * b;
#>   for (s in 1 : n_series) {
#>     mus[1 : n, s] = eta[ytimes[1 : n, s]] + trend[1 : n, s];
#>     ypred[1 : n, s] = poisson_log_rng(mus[1 : n, s]);
#>   }
#> }
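# In this centred version the latent trend is sampled directly, via
# trend[2 : n, s] ~ normal(ar1[s] * trend[1 : (n - 1), s], sigma[s]),
# so the trend's posterior geometry is tightly coupled to ar1 and sigma,
# which can slow mixing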
# Model summary and diagnostics
summary(mod)
#> GAM formula:
#> y ~ 1
#>
#> Family:
#> poisson
#>
#> Link function:
#> log
#>
#> Trend model:
#> AR(p = 1)
#>
#>
#> N series:
#> 1
#>
#> N timepoints:
#> 100
#>
#> Status:
#> Fitted using Stan
#> 4 chains, each with iter = 1000; warmup = 500; thin = 1
#> Total post-warmup draws = 2000
#>
#>
#> GAM coefficient (beta) estimates:
#>             2.5% 50% 97.5% Rhat n_eff
#> (Intercept) 0.86 1.2   1.6 1.02   159
#>
#> Latent trend parameter AR estimates:
#>          2.5%  50% 97.5% Rhat n_eff
#> ar1[1]   0.54 0.75  0.91 1.00   372
#> sigma[1] 0.33 0.47  0.64 1.01   278
#>
#> Stan MCMC diagnostics:
#> n_eff / iter looks reasonable for all parameters
#> Rhat looks reasonable for all parameters
#> 0 of 2000 iterations ended with a divergence (0%)
#> 0 of 2000 iterations saturated the maximum tree depth of 10 (0%)
#> E-FMI indicated no pathological behavior
#>
#> Samples were drawn using NUTS(diag_e) at Sat Nov 16 8:47:43 AM 2024.
#> For each parameter, n_eff is a crude measure of effective sample size,
#> and Rhat is the potential scale reduction factor on split MCMC chains
#> (at convergence, Rhat = 1)
mcmc_plot(mod, type = 'neff_hist')
#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
mcmc_plot(mod,
          variable = c('ar1', 'sigma', 'Intercept'),
          regex = TRUE,
          type = 'hist')
#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
plot(mod, type = 'forecast')
# The noncentred parameterisation makes the generated Stan code a bit
# harder to read, but it often samples latent autoregressive processes
# more efficiently (i.e. more effective samples per iteration)
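# A minimal sketch of the trick itself (illustrative only; z and
# trend_nc are made-up names, reusing the phi and sigma values from the
# simulation above): rather than sampling
#   trend[t] ~ Normal(phi * trend[t - 1], sigma)
# directly, sample standard-normal innovations and build the trend from
# them deterministically, so the sampler works on quantities that are
# a priori independent of phi and sigma
z <- rnorm(T)
trend_nc <- numeric(T)
trend_nc[1] <- sigma * z[1]
for (t in 2:T) {
  trend_nc[t] <- phi * trend_nc[t - 1] + sigma * z[t]
}

# Refit the same model using the noncentred parameterisation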
mod2 <- mvgam(y ~ 1,
              trend_model = AR(p = 1),
              family = poisson(),
              noncentred = TRUE,
              data = dat)
#> Compiling Stan program using cmdstanr
#>
#> Start sampling
#> Running MCMC with 4 parallel chains...
#>
#> All 4 chains finished successfully.
#> Mean chain execution time: 0.7 seconds.
#> Total execution time: 1.4 seconds.
stancode(mod2)
#> // Stan model code generated by package mvgam
#> data {
#>   int<lower=0> total_obs; // total number of observations
#>   int<lower=0> n; // number of timepoints per series
#>   int<lower=0> n_series; // number of series
#>   int<lower=0> num_basis; // total number of basis coefficients
#>   matrix[total_obs, num_basis] X; // mgcv GAM design matrix
#>   array[n, n_series] int<lower=0> ytimes; // time-ordered matrix (which col in X belongs to each [time, series] observation?)
#>   int<lower=0> n_nonmissing; // number of nonmissing observations
#>   array[n_nonmissing] int<lower=0> flat_ys; // flattened nonmissing observations
#>   matrix[n_nonmissing, num_basis] flat_xs; // X values for nonmissing observations
#>   array[n_nonmissing] int<lower=0> obs_ind; // indices of nonmissing observations
#> }
#> parameters {
#>   // raw basis coefficients
#>   vector[num_basis] b_raw;
#>
#>   // latent trend AR1 terms
#>   vector<lower=-1, upper=1>[n_series] ar1;
#>
#>   // latent trend variance parameters
#>   vector<lower=0>[n_series] sigma;
#>
#>   // raw latent trends
#>   matrix[n, n_series] trend_raw;
#> }
#> transformed parameters {
#>   // basis coefficients
#>   vector[num_basis] b;
#>
#>   // latent trends
#>   matrix[n, n_series] trend;
#>   trend = trend_raw .* rep_matrix(sigma', rows(trend_raw));
#>   for (s in 1 : n_series) {
#>     trend[2 : n, s] += ar1[s] * trend[1 : (n - 1), s];
#>   }
#>   b[1 : num_basis] = b_raw[1 : num_basis];
#> }
#> model {
#>   // prior for (Intercept)...
#>   b_raw[1] ~ student_t(3, 1.1, 2.5);
#>
#>   // priors for AR parameters
#>   ar1 ~ std_normal();
#>
#>   // priors for latent trend variance parameters
#>   sigma ~ student_t(3, 0, 2.5);
#>   to_vector(trend_raw) ~ std_normal();
#>   {
#>     // likelihood functions
#>     vector[n_nonmissing] flat_trends;
#>     flat_trends = to_vector(trend)[obs_ind];
#>     flat_ys ~ poisson_log_glm(append_col(flat_xs, flat_trends), 0.0,
#>                               append_row(b, 1.0));
#>   }
#> }
#> generated quantities {
#>   vector[total_obs] eta;
#>   matrix[n, n_series] mus;
#>   vector[n_series] tau;
#>   array[n, n_series] int ypred;
#>   for (s in 1 : n_series) {
#>     tau[s] = pow(sigma[s], -2.0);
#>   }
#>
#>   // posterior predictions
#>   eta = X * b;
#>   for (s in 1 : n_series) {
#>     mus[1 : n, s] = eta[ytimes[1 : n, s]] + trend[1 : n, s];
#>     ypred[1 : n, s] = poisson_log_rng(mus[1 : n, s]);
#>   }
#> }
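# The key difference from the centred model: trend_raw receives a
# std_normal() prior and the latent trend is assembled from it
# deterministically in the transformed parameters block, so the sampler
# explores innovations that are a priori independent of ar1 and sigma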
summary(mod2)
#> GAM formula:
#> y ~ 1
#>
#> Family:
#> poisson
#>
#> Link function:
#> log
#>
#> Trend model:
#> AR(p = 1)
#>
#>
#> N series:
#> 1
#>
#> N timepoints:
#> 100
#>
#> Status:
#> Fitted using Stan
#> 4 chains, each with iter = 1000; warmup = 500; thin = 1
#> Total post-warmup draws = 2000
#>
#>
#> GAM coefficient (beta) estimates:
#>             2.5% 50% 97.5% Rhat n_eff
#> (Intercept) 0.93 1.2   1.4    1   793
#>
#> Latent trend parameter AR estimates:
#>          2.5%  50% 97.5% Rhat n_eff
#> ar1[1]   0.46 0.79  0.99    1  1074
#> sigma[1] 0.38 0.51  0.68    1   793
#>
#> Stan MCMC diagnostics:
#> n_eff / iter looks reasonable for all parameters
#> Rhat looks reasonable for all parameters
#> 0 of 2000 iterations ended with a divergence (0%)
#> 0 of 2000 iterations saturated the maximum tree depth of 10 (0%)
#> E-FMI indicated no pathological behavior
#>
#> Samples were drawn using NUTS(diag_e) at Sat Nov 16 8:48:27 AM 2024.
#> For each parameter, n_eff is a crude measure of effective sample size,
#> and Rhat is the potential scale reduction factor on split MCMC chains
#> (at convergence, Rhat = 1)
mcmc_plot(mod2, type = 'neff_hist')
#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
mcmc_plot(mod2,
          variable = c('ar1', 'sigma', 'Intercept'),
          regex = TRUE,
          type = 'hist')
#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
plot(mod2, type = 'forecast')
Created on 2024-11-16 with reprex v2.0.2
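For comparison, the analogous model in {brms} uses an ar() term in the model formula. The sketch below was not run as part of this reprex, so treat it as a starting point rather than verified code; in particular, for non-Gaussian families such as Poisson you will likely need the covariance-based formulation of the AR term (cov = TRUE).

# Hedged sketch of an equivalent Poisson AR(1) model in brms (not run;
# check the arguments against your installed version)
library(brms)
mod_brms <- brm(y ~ 1 + ar(time = time, p = 1, cov = TRUE),
                data = dat,
                family = poisson())

# brms also autogenerates Stan code that you can inspect
stancode(mod_brms)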