Help with correlated binomial random walk state space model

Model

I’m new to building models directly in Stan, and I’d be grateful for help implementing a state space model that has similarities to a simple two-party election forecasting model:

  1. there are multiple groups observed over time with binomial counts (successes out of trials), but not every group is observed at every timestep (i.e., there are missing data)
  2. the goal is to infer the latent state, Pr(success), for each group at every timestep of the observation period, and also to forecast this latent state a modest number of timesteps into the future
  3. the latent state changes from one timestep to the next via a correlated random walk whose covariance matrix is inferred from the data
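As a rough illustration of items 2 and 3, here is a minimal R sketch of a correlated random walk on the logit scale. The helper name `simulate_rw` is hypothetical and does not come from any package. Each step adds one draw from a multivariate normal with step covariance `Sigma`, generated as the lower Cholesky factor of `Sigma` times independent standard normals. Under this view, forecasting the latent state amounts to continuing the walk from the last inferred state.

```r
# Sketch: simulate a correlated random walk on the logit scale.
# simulate_rw is a hypothetical helper; Sigma is the (n_group x n_group) step covariance.
simulate_rw <- function(z_start, n_steps, Sigma) {
  n_group <- length(z_start)
  L <- t(chol(Sigma))                     # lower-triangular factor, L %*% t(L) == Sigma
  z <- matrix(NA_real_, nrow = n_steps, ncol = n_group)  # rows = time, cols = groups
  z_prev <- z_start
  for (t in seq_len(n_steps)) {
    step <- as.vector(L %*% rnorm(n_group))  # one draw from MVN(0, Sigma)
    z_prev <- z_prev + step                  # random walk: previous state plus step
    z[t, ] <- z_prev
  }
  z
}

# Usage: forecast 5 steps ahead for 2 groups from their last (logit) state
set.seed(1)
Sigma <- matrix(c(0.09, 0.054, 0.054, 0.09), nrow = 2)  # sd 0.3 each, correlation 0.6
z_fc <- simulate_rw(z_start = c(-0.5, 0.2), n_steps = 5, Sigma = Sigma)
p_fc <- plogis(z_fc)  # forecast Pr(success): rows = steps ahead, cols = groups
```

For a fitted model, one option is to repeat this for each posterior draw of the final state and the step covariance, so the forecasts carry the posterior uncertainty.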

At the core of the model are draws from a multivariate normal distribution, which requires a covariance matrix $\Sigma$ that governs the state changes of the groups. As I understand it, I can construct $\Sigma$ by first creating a diagonal matrix $\mathbf{S}$ from the vector of group standard deviations $\sigma_{n}$, which set the scale of each group's step sizes, and then pre- and post-multiplying the group correlation matrix $\mathbf{R}$ by $\mathbf{S}$. Priors can then be placed on both $\sigma_{n}$ and $\mathbf{R}$. In other words,

$$
\begin{aligned}
\Sigma &= \mathbf{S} \mathbf{R} \mathbf{S} \\
\mathbf{S} &= \left(\begin{matrix} \sigma_{1} & 0 & \dots & 0 \\ 0 & \sigma_{2} & \ddots & \vdots \\ \vdots & \ddots & \ddots & 0 \\ 0 & \dots & 0 & \sigma_{n} \end{matrix}\right) \\
\sigma_{n} &\sim \mathrm{Gamma}(1, 1) \\
\mathbf{R} &\sim \text{LKJcorr}(1)
\end{aligned}
$$
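A small numeric check of this construction in R, using made-up example values for `sigma` and `R`: because the correlation matrix has a unit diagonal, the diagonal of $\Sigma$ is $\sigma_n^2$, and rescaling $\Sigma$ back to a correlation matrix with `stats::cov2cor()` recovers the correlation matrix. The elementwise form `R * outer(sigma, sigma)` gives the same matrix without building the diagonal matrix explicitly.

```r
# Sketch: build the step covariance Sigma = S R S from group SDs and a correlation matrix.
sigma <- c(0.3, 0.5, 0.2)                       # example per-group step-size SDs
R <- matrix(c( 1.0, 0.4, -0.2,
               0.4, 1.0,  0.1,
              -0.2, 0.1,  1.0), nrow = 3, byrow = TRUE)  # example correlation matrix
S <- diag(sigma)                                # diagonal matrix of SDs
Sigma <- S %*% R %*% S                          # Sigma[i, j] = sigma[i] * sigma[j] * R[i, j]

# Equivalent elementwise form, without building S
Sigma_alt <- R * outer(sigma, sigma)

diag(Sigma)       # equals sigma^2, because diag(R) is all 1
cov2cor(Sigma)    # recovers R
```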

Problem

The model in which I’ve attempted to implement this covariance structure (code and data below) fails to sample: every chain rejects its initial values with the same exception. For example:

```text
Chain 1 Rejecting initial value:
Chain 1   Error evaluating the log probability at the initial value.
Chain 1 Exception: lkj_corr_lpdf: Correlation matrix is not a valid correlation matrix.
Correlation matrix(1,1) is 6.01864, but should be near 1.0
(in '/var/folders/04/jbnbgfsj0y537r7cr7180dx40000gq/T/Rtmp256Pzl/model-1077c528e1a8f.stan',
line 40, column 4 to column 34)
```

The error points to this statement in the model block:

```stan
rw_corr ~ lkj_corr(corr0_eta);
```

where `corr0_eta` is a user-supplied shape parameter for the LKJ correlation distribution.
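The message means the matrix passed to `lkj_corr()` is not a correlation matrix: its first diagonal entry is about 6 instead of 1. The v1 model file is only attached, so the cause here is an assumption, but this typically happens when `rw_corr` is not declared with Stan's `corr_matrix` type (for example, when it is a plain `matrix` or `cov_matrix`, or is computed as the covariance $\mathbf{S}\mathbf{R}\mathbf{S}$ itself). Only a `corr_matrix` declaration, or a `cholesky_factor_corr` paired with `lkj_corr_cholesky`, guarantees a valid argument. The hypothetical R helper below mirrors the properties a correlation matrix must have and shows that a covariance matrix built as S R S fails them whenever some standard deviation differs from 1.

```r
# Hypothetical helper: checks the properties a valid correlation matrix must have
# (square, symmetric, unit diagonal, positive definite).
is_valid_corr <- function(M, tol = 1e-8) {
  is.matrix(M) && nrow(M) == ncol(M) &&
    isSymmetric(M, tol = tol) &&
    all(abs(diag(M) - 1) <= tol) &&
    all(eigen(M, symmetric = TRUE, only.values = TRUE)$values > 0)
}

sigma <- c(2, 0.3)                          # example SDs
R <- matrix(c(1, 0.5, 0.5, 1), nrow = 2)    # a valid correlation matrix
Sigma <- diag(sigma) %*% R %*% diag(sigma)  # covariance matrix, diagonal = sigma^2

is_valid_corr(R)      # TRUE
is_valid_corr(Sigma)  # FALSE: Sigma[1, 1] is 4, not 1
```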

Questions

  1. What am I doing wrong to cause the error, and how do I fix it? Is it a simple parameterization error, bad priors, bad initial values, or is the covariance structure improperly implemented? I’ve seen similar models use a Cholesky decomposition for more complex correlation structures (e.g., reconciling local vs. national polls). I don’t understand whether I need to do this and (if so) how to set it up.
  2. Are there obvious improvements to structure or efficiency that should be made (code and data below)? Because the model errors out, I can’t tell whether it does what I intend or samples efficiently. The key areas I’m uncertain about are:
  • handling of missing observations (using an indicator variable in the user-supplied data to control which observations contribute to the likelihood in the model block)
  • handling of forecast timesteps (would it be better to move these into a generated quantities block, separate from the model block?)
  • is there a more efficient way than for loops in the model block? I tried to extend the sliced missing data example from the Stan manual to this matrix, but couldn’t figure out how to pass the matrix columns to multi_normal in a vectorized way. I’d appreciate help doing this if it would improve efficiency.
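One common alternative to the indicator-plus-`if` loop is to pass Stan only the observed cells, plus their positions. The R sketch below uses a tiny made-up example; the names `obs_idx`, `n_obs_cells`, `successes_obs`, and `trials_obs` are hypothetical. `which()` returns column-major linear indices, and Stan's `to_vector()` also flattens a matrix in column-major order. So, assuming the latent states are available as a `TT x n_group` matrix `z`, the likelihood could then be written as a single vectorized statement such as `successes_obs ~ binomial_logit(trials_obs, to_vector(z)[obs_idx]);`, with `obs_idx` declared as an integer array in the data block.

```r
# Sketch: keep only observed cells of TT x n_group matrices, plus their positions,
# so the likelihood can be vectorized over observed cells only.
# Tiny made-up example; in practice use the reprex's dat_obs, successes, trials.
dat_obs   <- matrix(c(1L, 0L, 1L,
                      1L, 1L, 0L), nrow = 3)      # 3 timesteps x 2 groups
trials    <- matrix(c(120L, 9999L, 300L,
                      250L, 410L, 9999L), nrow = 3)
successes <- matrix(c(60L, 9999L, 90L,
                      100L, 205L, 9999L), nrow = 3)

obs_idx <- which(dat_obs == 1L)          # column-major linear indices of observed cells
stan_obs <- list(
  n_obs_cells   = length(obs_idx),
  obs_idx       = obs_idx,
  successes_obs = successes[obs_idx],    # sentinel 9999 values are dropped
  trials_obs    = trials[obs_idx]
)
```

With this layout the 9999 sentinel values never reach Stan, so they cannot accidentally enter the likelihood.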

reprex

I’m using the cmdstanr package v.0.4.0.9001 to fit the model with R 4.1.2 on an Apple M1 processor (aarch64).

Stan code: ssm_correlated-random-walk.stan (2.5 KB)
Example data: dat.Rdata (2.3 KB)
–or–
generate example data in R

```r
library(MASS)        # for MVNormal distribution functions
library(rethinking)  # for LKJ correlation matrix distribution functions

set.seed(23981)
# data indices
n_time <- 50    # number of observation time steps
n_fc <- 5       # number of time steps ahead to forecast
n_group <- 10   # number of groups where state is observed
n <- (n_time + n_fc) * n_group             # total number of observations (incl. forecasts)
dat_obs <- as.integer(replicate(n_group,   # indicate data presence (~80% obs, 0% forecast)
                                rbinom(n_time + n_fc, size = 1,
                                       prob = c(rep(0.8, n_time), rep(0, n_fc)))))
# true parameter values (latent)
corr <- rethinking::rlkjcorr(n = 1, K = n_group, eta = 0.5)  # correlations
sds <- rep(0.3, n_group)                    # standard deviations of step sizes
covmat <- diag(sds) %*% corr %*% diag(sds)  # covariance matrix S R S
logit_true_p <- as.vector(apply(MASS::mvrnorm(n = n_time + n_fc,
                                              mu = rep(0, n_group),
                                              Sigma = covmat),
                                MARGIN = 2, cumsum))  # random walk per group, stacked group by group
true_p <- plogis(logit_true_p)
# observed data
time_group <- rep(1:(n_time + n_fc), times = n_group)
group <- factor(rep(letters[1:n_group], each = n_time + n_fc))
trials <- as.integer(replicate(n_group, c(sample(100:500, size = n_time, replace = TRUE),
                                          rep(9999, n_fc))))
successes <- rbinom(n, size = trials, prob = true_p)
trials[which(dat_obs == 0)] <- 9999L        # flag missing cases with the 9999 sentinel
successes[which(trials == 9999L)] <- 9999L
# prior parameters
corr0_eta <- 1                   # LKJ(_)
sigma0_shape <- rep(1, n_group)  # gamma(_,rate)
sigma0_rate <- rep(1, n_group)   # gamma(shape,_)
pr0_alpha <- rep(1, n_group)     # beta(_,beta)
pr0_beta <- rep(1, n_group)      # beta(alpha,_)

# list data to be passed to Stan
dat <- list(n = n, n_time = n_time, n_group = n_group, n_fc = n_fc, dat_obs = dat_obs,
            time_group = time_group,
            group = as.integer(group),  # Stan integer data: pass group ids, not a factor
            trials = trials, successes = successes,
            sigma0_shape = sigma0_shape, sigma0_rate = sigma0_rate,
            corr0_eta = corr0_eta, pr0_alpha = pr0_alpha, pr0_beta = pr0_beta)
```

R code to compile and sample the model

```r
library(cmdstanr)
# compile the model
mod <- cmdstan_model("ssm_correlated-random-walk.stan")

# sample the model
fit <- mod$sample(data = dat)
```

The surveil R package has correlated random walk models for count data (Poisson models). I don’t think the models in the package will fit your purpose, but you might be able to borrow from its Stan code, especially for getting the covariance matrix to work (efficiently). I think the page ‘1.13 Multivariate priors for hierarchical models’ in the Stan User’s Guide was my reference.


Thanks @cmcd! Your example for the correlated Poisson random walk had what I needed to get my model sampling.

I’m posting my code here in case it’s useful to others or in case folks have ideas for improving efficiency. I’m getting divergences, and I will explore whether they come from overly weak priors or from a deeper structural pathology in the model. The key changes from v1 are:

  • Reshaping the input data into matrix form (vs. original vector form). This also helped tidy up the data by getting rid of time and group indices, and I transposed my z matrix to follow your rows=time and cols=groups scheme.
  • Implementing the Cholesky factorization of the covariance matrix (this example was also helpful)
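A quick R check, with made-up example values, of why the Cholesky version works: the lower Cholesky factor of $\mathbf{S}\mathbf{R}\mathbf{S}$ equals $\mathbf{S}$ times the lower Cholesky factor of $\mathbf{R}$, which is the product `diag_pre_multiply(rw_sd, L_Omega)` forms in the v2 model. It also shows that only the correlation factor has rows of unit length. The LKJ density is defined on Cholesky factors of correlation matrices, so the `lkj_corr_cholesky` prior belongs on `L_Omega` rather than on the scaled factor `rw_L`.

```r
# Sketch: Cholesky factor of Sigma = S R S equals S times the Cholesky factor of R.
sigma <- c(0.3, 0.5, 0.2)                       # example group step-size SDs
R <- matrix(c( 1.0, 0.4, -0.2,
               0.4, 1.0,  0.1,
              -0.2, 0.1,  1.0), nrow = 3, byrow = TRUE)  # example correlation matrix
Sigma <- diag(sigma) %*% R %*% diag(sigma)      # covariance S R S

L_Omega <- t(chol(R))                 # chol() returns the upper factor; transpose to lower
L_Sigma <- diag(sigma) %*% L_Omega    # row i scaled by sigma[i], as diag_pre_multiply() does

all.equal(L_Sigma, t(chol(Sigma)))    # TRUE: same as factoring Sigma directly
rowSums(L_Omega^2)                    # all 1 (up to rounding): unit-length rows
rowSums(L_Sigma^2)                    # equals sigma^2: rows of rw_L are not unit length
```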

ssm_correlated-random-walk_v2.stan (2.2 KB)

```stan
// ssm_correlated-random-walk_v2.stan
// correlated binomial random walk
data {
  // observed data & indices
  int<lower=1> TT;                               // number of timesteps (observed + forecast)
  int<lower=0> n_fc;                             // number of timesteps ahead to forecast
  int<lower=1> n_group;                          // number of populations
  array[TT, n_group] int<lower=0> successes;     // matrix of successes
  array[TT, n_group] int<lower=0> trials;        // matrix of trials
  array[TT, n_group] int<lower=0, upper=1> dat_obs;  // observed (1) or missing (0)

  // user-supplied priors
  real<lower=0> corr0_eta;                       // group correlations: LKJ distribution
  array[n_group] real<lower=0> sigma0_shape;     // random walk sd step size: gamma (shape)
  array[n_group] real<lower=0> sigma0_rate;      // random walk sd step size: gamma (rate)
  array[n_group] real<lower=0> pr0_alpha;        // starting pr(success): beta [alpha]
  array[n_group] real<lower=0> pr0_beta;         // starting pr(success): beta [beta]
}
parameters {
  vector[n_group] z0;                      // initial latent state (logit scale)
  array[TT] vector[n_group] z;             // logit(pr(success)); rows = time, cols = groups
  vector<lower=0>[n_group] rw_sd;          // sd of random walk step sizes for each group
  cholesky_factor_corr[n_group] L_Omega;   // Cholesky factor of the group correlation matrix
}
transformed parameters {
  // Cholesky factor of the step covariance matrix: diag(rw_sd) * L_Omega
  matrix[n_group, n_group] rw_L = diag_pre_multiply(rw_sd, L_Omega);
  vector<lower=0, upper=1>[n_group] pr0 = inv_logit(z0);  // initial Pr(success)
}
model {
  // PRIORS
  pr0 ~ beta(pr0_alpha, pr0_beta);
  // Jacobian adjustment: the beta prior is placed on pr0 = inv_logit(z0)
  target += log_inv_logit(z0) + log1m_inv_logit(z0);
  L_Omega ~ lkj_corr_cholesky(corr0_eta);  // LKJ prior on the correlation factor, not rw_L
  rw_sd ~ gamma(sigma0_shape, sigma0_rate);

  // LATENT STATE MODEL (vectorized over timesteps)
  z[1] ~ multi_normal_cholesky(z0, rw_L);
  z[2:TT] ~ multi_normal_cholesky(z[1:(TT - 1)], rw_L);

  // OBSERVATION MODEL
  for (t in 1:TT) {
    for (g in 1:n_group) {
      // evaluate likelihood function _only_ at timesteps that have an observation
      if (dat_obs[t, g] == 1) {
        successes[t, g] ~ binomial_logit(trials[t, g], z[t, g]);
      }
    }
  }
}
```

R code to simulate data:

```r
library(MASS)        # for MVNormal distribution functions
library(rethinking)  # remotes::install_github("rmcelreath/rethinking")

set.seed(23981)
# data indices
n_obs <- 50          # number of observation time steps
n_fc <- 5            # number of time steps ahead to forecast
TT <- n_obs + n_fc   # total number of time steps
n_group <- 10        # number of groups where state is observed
dat_obs <- replicate(n_group,  # indicate data presence (~80% obs, 0% forecast)
                     rbinom(TT, size = 1,
                            prob = c(rep(0.8, n_obs), rep(0, n_fc))))
# true parameter values (latent)
corr <- rethinking::rlkjcorr(n = 1, K = n_group, eta = 0.5)  # correlations
sds <- rep(0.3, n_group)                    # standard deviations of step sizes
covmat <- diag(sds) %*% corr %*% diag(sds)  # covariance matrix S R S
logit_true_p <- apply(MASS::mvrnorm(n = TT, mu = rep(0, n_group),
                                    Sigma = covmat),
                      MARGIN = 2, cumsum)   # TT x n_group random walk (rows = time)
true_p <- plogis(logit_true_p)
trials <- replicate(n_group, c(sample(100:500, size = n_obs, replace = TRUE),
                               rep(9999L, n_fc)))
trials[dat_obs == 0L] <- 9999L              # flag missing cases with the 9999 sentinel
# rbinom() is vectorized over size and prob, so all cells can be drawn at once
successes <- matrix(rbinom(length(trials), size = trials, prob = true_p),
                    nrow = TT, ncol = n_group)
successes[dat_obs == 0L] <- 9999L
# prior parameters
corr0_eta <- 1                   # LKJ(_)
sigma0_shape <- rep(1, n_group)  # gamma(_,rate)
sigma0_rate <- rep(0.5, n_group) # gamma(shape,_)
pr0_alpha <- rep(1, n_group)     # beta(_,beta)
pr0_beta <- rep(1, n_group)      # beta(alpha,_)

dat <- list(TT = TT, n_group = n_group, n_fc = n_fc, dat_obs = dat_obs,
            trials = trials, successes = successes,
            sigma0_shape = sigma0_shape, sigma0_rate = sigma0_rate,
            corr0_eta = corr0_eta, pr0_alpha = pr0_alpha, pr0_beta = pr0_beta)
```