Improving efficiency of latent autoregressive model

Hi Discourse,
I am fitting models to a range of datasets in which I need to estimate latent autoregressive time series that contribute to the linear predictor of discrete observation models (Poisson and Negative Binomial). The discrete observations are supplied as a flattened (column-major order) array so that I can use Stan's built-in GLM functions. The code works well and I've tried to incorporate slicing where possible, but I'm wondering whether there are other ways to speed up sampling, as the computational burden becomes large once I scale up to 15 - 20 series and add other complexities to the linear predictors. The Stan model code and the R code for creating a reproducible example are shown below.
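To make the setup concrete, the model I am trying to estimate for each series $s$ is roughly the following (same names as in the code below, with $z_{t,s}$ standing in for trend):

$$
y_{t,s} \sim \text{Poisson}\!\left(\exp(\alpha_s + z_{t,s})\right), \qquad
z_{t,s} \sim \text{Normal}\!\left(\text{ar1}_s \, z_{t-1,s} + \text{ar2}_s \, z_{t-2,s},\ \sigma_s\right),
$$

with the first two latent states initialised as in the Stan model block further down.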

#### Function to simulate Poisson observations over a latent real-valued autoregressive
# time series ####
sim_ar_discrete = function(n_series = 4, trend_model = 'AR1', N = 100){
  
  # Simulate AR parameters
  if(trend_model == 'AR1'){
    ar1s <- rnorm(n_series, sd = 0.3)
    ar2s <- rep(0, n_series)
    ar3s <- rep(0, n_series)
  }
  
  if(trend_model == 'AR2'){
    ar1s <- rnorm(n_series, sd = 0.3)
    ar2s <- rnorm(n_series, sd = 0.3)
    ar3s <- rep(0, n_series)
  }
  
  if(trend_model == 'AR3'){
    ar1s <- rnorm(n_series, sd = 0.3)
    ar2s <- rnorm(n_series, sd = 0.3)
    ar3s <- rnorm(n_series, sd = 0.3)
  }
  
  # Simulate intercept parameters
  alphas <- runif(n_series, -0.5, 1.5)
  
  # Function to simulate trends ahead using ar3 model
  sim_ar3 = function(ar1, ar2, ar3, N){
    states <- rep(NA, length = N + 3)
    inits <- cumsum(rnorm(3, 0, 0.1))
    states[1] <- inits[1]
    states[2] <- inits[2]
    states[3] <- inits[3]
    for (t in 4:(N + 3)) {
      states[t] <- rnorm(1, ar1*states[t - 1] +
                           ar2*states[t - 2] +
                           ar3*states[t - 3], 1)
    }
    states[-c(1:3)]
  }
  
  # Simulate latent real-valued trends
  trends <- do.call(cbind, lapply(seq_len(n_series), function(x){
    sim_ar3(ar1 = ar1s[x],
            ar2 = ar2s[x],
            ar3 = ar3s[x],
            N = N)
  }))
  
  # Simulate Poisson observations
  obs_ys <- do.call(cbind, lapply(seq_len(n_series), function(x){
    rpois(N, exp(alphas[x] + trends[,x]))
  }))
  
  return(list(y = obs_ys,
              true_trends = trends,
              ar1s = ar1s,
              ar2s = ar2s,
              ar3s = ar3s,
              alphas = alphas))
}

# Simulate Poisson observations from independent AR2 models
data <- sim_ar_discrete(trend_model = 'AR2')

# Compile the model using Cmdstan
library(cmdstanr)
model_data <- list(n = NROW(data$y),
                   n_series = NCOL(data$y),
                   flat_ys = as.vector(data$y),
                   total_obs = length(data$y),
                   # Dummy coded design matrix (intercept terms only for now)
                   X = as.matrix(as.data.frame(model.matrix(~ sort(rep(paste0('series', 
                                                                              1:NCOL(data$y)), 
                                                                       NROW(data$y))) - 1))))

cmd_mod <- cmdstan_model(write_stan_file(stan_ar2_pois),
                         stanc_options = list('canonicalize=deprecations,braces,parentheses'))

# Condition the model on observed data
fit1 <- cmd_mod$sample(data = model_data,
                       chains = 4,
                       parallel_chains = 4,
                       refresh = 500)

# Inference
library(bayesplot)
mcmc_hist(fit1$draws("ar1"))
data$ar1s

mcmc_hist(fit1$draws("ar2"))
data$ar2s

mcmc_hist(fit1$draws("alpha"))
data$alphas

where stan_ar2_pois is:

data {
  int<lower=0> n;                    // number of timepoints per series
  int<lower=0> n_series;             // number of series
  int<lower=0> total_obs;            // total number of observations
  array[total_obs] int<lower=0> flat_ys;   // flattened observations (GLM compatibility)
  matrix[total_obs, n_series] X;     // design matrix for linear predictor
}

parameters {
  // intercept parameters
  vector[n_series] alpha;
  
  // latent trend AR1 parameters
  vector<lower=-1.5,upper=1.5>[n_series] ar1;
  
  // latent trend AR2 parameters
  vector<lower=-1.5,upper=1.5>[n_series] ar2;
  
  // latent trend variance parameters
  vector<lower=0>[n_series] sigma;
  
  // latent trends
  matrix[n, n_series] trend;
}

model {
  // priors for intercept parameters
  alpha ~ normal(0, 2);
  
  // priors for AR parameters
  ar1 ~ normal(0, 0.5);
  ar2 ~ normal(0, 0.5);
  
  // priors for latent trend variance parameters
  sigma ~ exponential(1);
  
  // trend estimates
  trend[1, 1:n_series] ~ normal(0, sigma);
  trend[2, 1:n_series] ~ normal(trend[1, 1:n_series] .* to_row_vector(ar1), sigma);
  for(s in 1:n_series){
    trend[3:n, s] ~ normal(ar1[s] * trend[2:(n - 1), s] + ar2[s] * trend[1:(n - 2), s], sigma[s]);
  }
  
  // likelihood functions
  vector[total_obs] flat_trends = to_vector(trend);
  flat_ys ~ poisson_log_glm(append_col(X, flat_trends), 0.0, append_row(alpha, 1.0));
}

generated quantities {
  // posterior predictions
  array[n, n_series] int ypred;
  for (s in 1:n_series) {
    ypred[1:n, s] = poisson_log_rng(alpha[s] + trend[1:n, s]);
  }
}

I’m most interested in whether this block:

  // trend estimates
  trend[1, 1:n_series] ~ normal(0, sigma);
  trend[2, 1:n_series] ~ normal(trend[1, 1:n_series] .* to_row_vector(ar1), sigma);
  for(s in 1:n_series){
    trend[3:n, s] ~ normal(ar1[s] * trend[2:(n - 1), s] + ar2[s] * trend[1:(n - 2), s], sigma[s]);
  }

can be made more efficient somehow, but I can't really see a way with my limited experience. One idea I have wondered about (but not tested) is sketched below.
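The sketch is a non-centred parameterisation of the latent trends: sample standardised innovations (trend_raw, my naming) and build the trends deterministically in transformed parameters, keeping everything else (priors, GLM likelihood) exactly as above. Treat it as a rough idea only; I haven't checked whether it actually helps for this model.

parameters {
  // same alpha, ar1, ar2 and sigma declarations as above, but the trend
  // matrix is replaced by standardised innovations
  matrix[n, n_series] trend_raw;
}

transformed parameters {
  // latent trends rebuilt deterministically from the innovations
  matrix[n, n_series] trend;
  for (s in 1:n_series) {
    trend[1, s] = sigma[s] * trend_raw[1, s];
    trend[2, s] = ar1[s] * trend[1, s] + sigma[s] * trend_raw[2, s];
    for (t in 3:n) {
      trend[t, s] = ar1[s] * trend[t - 1, s]
                    + ar2[s] * trend[t - 2, s]
                    + sigma[s] * trend_raw[t, s];
    }
  }
}

model {
  // standard normal innovations replace the centred trend statements;
  // the priors and the poisson_log_glm likelihood stay exactly as above
  to_vector(trend_raw) ~ std_normal();
}

I have no sense of whether this would sample faster here, or whether there is instead a cleverer vectorised form of the centred version. Any tips or tricks would be most helpful.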