Mathematical notation for a zero-inflated negative binomial nested model

Follow-up to the post: Mathematical Notation for a zero inflated negative binomial model in brms
which was explained in detail by @Max_Mantei (thanks for your excellent explanation!)

I have a slightly more complex scenario, with a nested design (e.g. pupils and subjects), and four numerical predictors. One of the predictors (A) is a pure population-level effect, i.e. I don’t foresee any difference between pupils and subjects.
The other three predictors, and the intercept, are expected to vary among groups.
I use a nested design (specified manually, because I would like to learn what I am doing), because I expect each pupil to have a certain “base behaviour” regardless of subject, plus, of course, subject-specific behaviour on top of that.

I use the following brms model:

formula <- bf(y ~ 1 + A + R + C + D + (1 + R + C + D | pupil) + (1 + R + C + D | pupil:subject),
              zi ~ 1 + A + R + C + D + (1 + R + C + D | pupil) + (1 + R + C + D | pupil:subject))

There are 11 pupils and 8 subjects, but not all pupils take all subjects, so pupil:subject has 84 levels rather than the 88 that a completely balanced design would give.
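
To illustrate how the number of pupil:subject levels falls short of the full cross, here is a small base-R sketch. The data frame `grid` and the four dropped combinations are made up for illustration; they are not my actual data.

```r
# Hypothetical toy design: 11 pupils x 8 subjects, with 4 combinations missing
grid <- expand.grid(pupil = paste0("p", 1:11), subject = paste0("s", 1:8))
grid <- grid[-c(1, 2, 3, 4), ]  # arbitrarily drop four pupil-subject pairs

n_pupil   <- length(unique(grid$pupil))    # levels of the pupil grouping
n_subject <- length(unique(grid$subject))  # number of distinct subjects
# Levels of the pupil:subject grouping that actually occur in the data
n_pairs   <- nlevels(interaction(grid$pupil, grid$subject, drop = TRUE))

c(pupils = n_pupil, subjects = n_subject, pairs = n_pairs)
```

Only combinations that occur in the data become levels of the pupil:subject grouping factor.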

The numerical predictors are log-transformed, and the resulting logarithms are centered and scaled, so each predictor has mean 0. The total number of observations is about 31000.
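
A minimal sketch of that preprocessing, assuming strictly positive raw values. The data frame `d`, the `_raw` columns and the helper `log_std` are hypothetical names used only for this illustration.

```r
# Hypothetical raw predictor values; the real data are not shown in this post
set.seed(1)
d <- data.frame(R_raw = rexp(100) + 0.1,
                C_raw = rexp(100) + 0.1)

# Log-transform, then center and scale the logarithm (mean 0, sd 1)
log_std <- function(x) as.numeric(scale(log(x)))
d$R <- log_std(d$R_raw)
d$C <- log_std(d$C_raw)

# Check the result: mean should be 0 and sd should be 1 (up to rounding)
round(c(mean = mean(d$R), sd = sd(d$R)), 10)
```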

Because I have quite a lot of zeros in my data (>90%, which is to be expected), and only consider outcomes up to about 1000 to be reasonable (with much less expected on average; Q99 would be about 20), I had to fiddle a bit with my priors.
I found these priors to make reasonably realistic predictions:

priors <- c(prior(normal(0, 0.5), class = Intercept),
            prior(normal(0, 0.25), class = b),
            prior(weibull(2, 0.25), class = sd),
            prior(weibull(2, 0.25), class = sd, group = pupil:subject),
            prior(lkj(2), class = cor),
            prior(normal(0, 0.5), class = Intercept, dpar = zi),
            prior(normal(0, 0.25), class = b, dpar = zi),
            prior(weibull(2, 0.25), class = sd, dpar = zi),
            prior(weibull(2, 0.25), class = sd, group = pupil:subject, dpar = zi),
            prior(gamma(0.1, 0.1), class = shape))
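
A rough base-R sketch of the kind of prior predictive check that can be used to judge whether such priors are "reasonably realistic". It is a simplification under stated assumptions, not the model itself: it only draws the two intercepts and the shape from the priors above, sets all predictors to their mean (0), and ignores the group-level effects, so it understates the prior spread of the full model.

```r
set.seed(123)
n_draws <- 10000

# Draws from the priors on the two intercepts and on the shape parameter
intercept    <- rnorm(n_draws, 0, 0.5)                    # log scale (mu)
intercept_zi <- rnorm(n_draws, 0, 0.5)                    # logit scale (zi)
shape        <- rgamma(n_draws, shape = 0.1, rate = 0.1)  # Stan's gamma is shape/rate

mu <- exp(intercept)        # inverse of the log link
zi <- plogis(intercept_zi)  # inverse of the logit link

# One simulated outcome per prior draw: a structural zero with probability zi,
# otherwise a negative binomial count with mean mu and size = shape
is_structural_zero <- rbinom(n_draws, 1, zi) == 1
y_sim <- ifelse(is_structural_zero, 0,
                rnbinom(n_draws, size = shape, mu = mu))

mean(y_sim == 0)                       # prior predictive share of zeros
quantile(y_sim, c(0.5, 0.9, 0.99, 1))  # upper tail of the simulated counts
```

The share of zeros and the upper quantiles can then be compared with what is plausible for the data (here: more than 90% zeros, Q99 around 20, maximum around 1000).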

I use the standard log link for the \mu of the negative binomial, and the logit link for the \xi of the zero-inflation part.

Here is the Stan code so far:

// generated with brms 2.20.4
functions {
 /* compute correlated group-level effects
  * Args:
  *   z: matrix of unscaled group-level effects
  *   SD: vector of standard deviation parameters
  *   L: cholesky factor correlation matrix
  * Returns:
  *   matrix of scaled group-level effects
  */
  matrix scale_r_cor(matrix z, vector SD, matrix L) {
    // r is stored in another dimension order than z
    return transpose(diag_pre_multiply(SD, L) * z);
  }
  /* zero-inflated negative binomial log-PDF of a single response
   * Args:
   *   y: the response value
   *   mu: mean parameter of negative binomial distribution
   *   phi: shape parameter of negative binomial distribution
   *   zi: zero-inflation probability
   * Returns:
   *   a scalar to be added to the log posterior
   */
  real zero_inflated_neg_binomial_lpmf(int y, real mu, real phi,
                                       real zi) {
    if (y == 0) {
      return log_sum_exp(bernoulli_lpmf(1 | zi),
                         bernoulli_lpmf(0 | zi) +
                         neg_binomial_2_lpmf(0 | mu, phi));
    } else {
      return bernoulli_lpmf(0 | zi) +
             neg_binomial_2_lpmf(y | mu, phi);
    }
  }
  /* zero-inflated negative binomial log-PDF of a single response
   * logit parameterization of the zero-inflation part
   * Args:
   *   y: the response value
   *   mu: mean parameter of negative binomial distribution
   *   phi: shape parameter of negative binomial distribution
   *   zi: linear predictor for zero-inflation part
   * Returns:
   *   a scalar to be added to the log posterior
   */
  real zero_inflated_neg_binomial_logit_lpmf(int y, real mu,
                                             real phi, real zi) {
    if (y == 0) {
      return log_sum_exp(bernoulli_logit_lpmf(1 | zi),
                         bernoulli_logit_lpmf(0 | zi) +
                         neg_binomial_2_lpmf(0 | mu, phi));
    } else {
      return bernoulli_logit_lpmf(0 | zi) +
             neg_binomial_2_lpmf(y | mu, phi);
    }
  }
  /* zero-inflated negative binomial log-PDF of a single response
   * log parameterization for the negative binomial part
   * Args:
   *   y: the response value
   *   eta: linear predictor for negative binomial distribution
   *   phi: shape parameter of negative binomial distribution
   *   zi: zero-inflation probability
   * Returns:
   *   a scalar to be added to the log posterior
   */
  real zero_inflated_neg_binomial_log_lpmf(int y, real eta,
                                           real phi, real zi) {
    if (y == 0) {
      return log_sum_exp(bernoulli_lpmf(1 | zi),
                         bernoulli_lpmf(0 | zi) +
                         neg_binomial_2_log_lpmf(0 | eta, phi));
    } else {
      return bernoulli_lpmf(0 | zi) +
             neg_binomial_2_log_lpmf(y | eta, phi);
    }
  }
  /* zero-inflated negative binomial log-PDF of a single response
   * log parameterization for the negative binomial part
   * logit parameterization of the zero-inflation part
   * Args:
   *   y: the response value
   *   eta: linear predictor for negative binomial distribution
   *   phi: shape parameter of negative binomial distribution
   *   zi: linear predictor for zero-inflation part
   * Returns:
   *   a scalar to be added to the log posterior
   */
  real zero_inflated_neg_binomial_log_logit_lpmf(int y, real eta,
                                                 real phi, real zi) {
    if (y == 0) {
      return log_sum_exp(bernoulli_logit_lpmf(1 | zi),
                         bernoulli_logit_lpmf(0 | zi) +
                         neg_binomial_2_log_lpmf(0 | eta, phi));
    } else {
      return bernoulli_logit_lpmf(0 | zi) +
             neg_binomial_2_log_lpmf(y | eta, phi);
    }
  }
  // zero_inflated negative binomial log-CCDF and log-CDF functions
  real zero_inflated_neg_binomial_lccdf(int y, real mu, real phi, real hu) {
    return bernoulli_lpmf(0 | hu) + neg_binomial_2_lccdf(y | mu, phi);
  }
  real zero_inflated_neg_binomial_lcdf(int y, real mu, real phi, real hu) {
    return log1m_exp(zero_inflated_neg_binomial_lccdf(y | mu, phi, hu));
  }
}
data {
  int<lower=1> N;  // total number of observations
  array[N] int Y;  // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  int<lower=1> Kc;  // number of population-level effects after centering
  int<lower=1> K_zi;  // number of population-level effects
  matrix[N, K_zi] X_zi;  // population-level design matrix
  int<lower=1> Kc_zi;  // number of population-level effects after centering
  // data for group-level effects of ID 1
  int<lower=1> N_1;  // number of grouping levels
  int<lower=1> M_1;  // number of coefficients per level
  array[N] int<lower=1> J_1;  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_1_1;
  vector[N] Z_1_2;
  vector[N] Z_1_3;
  vector[N] Z_1_4;
  int<lower=1> NC_1;  // number of group-level correlations
  // data for group-level effects of ID 2
  int<lower=1> N_2;  // number of grouping levels
  int<lower=1> M_2;  // number of coefficients per level
  array[N] int<lower=1> J_2;  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_2_1;
  vector[N] Z_2_2;
  vector[N] Z_2_3;
  vector[N] Z_2_4;
  int<lower=1> NC_2;  // number of group-level correlations
  // data for group-level effects of ID 3
  int<lower=1> N_3;  // number of grouping levels
  int<lower=1> M_3;  // number of coefficients per level
  array[N] int<lower=1> J_3;  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_3_zi_1;
  vector[N] Z_3_zi_2;
  vector[N] Z_3_zi_3;
  vector[N] Z_3_zi_4;
  int<lower=1> NC_3;  // number of group-level correlations
  // data for group-level effects of ID 4
  int<lower=1> N_4;  // number of grouping levels
  int<lower=1> M_4;  // number of coefficients per level
  array[N] int<lower=1> J_4;  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_4_zi_1;
  vector[N] Z_4_zi_2;
  vector[N] Z_4_zi_3;
  vector[N] Z_4_zi_4;
  int<lower=1> NC_4;  // number of group-level correlations
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  matrix[N, Kc] Xc;  // centered version of X without an intercept
  vector[Kc] means_X;  // column means of X before centering
  matrix[N, Kc_zi] Xc_zi;  // centered version of X_zi without an intercept
  vector[Kc_zi] means_X_zi;  // column means of X_zi before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
  for (i in 2:K_zi) {
    means_X_zi[i - 1] = mean(X_zi[, i]);
    Xc_zi[, i - 1] = X_zi[, i] - means_X_zi[i - 1];
  }
}
parameters {
  vector[Kc] b;  // regression coefficients
  real Intercept;  // temporary intercept for centered predictors
  real<lower=0> shape;  // shape parameter
  vector[Kc_zi] b_zi;  // regression coefficients
  real Intercept_zi;  // temporary intercept for centered predictors
  vector<lower=0>[M_1] sd_1;  // group-level standard deviations
  matrix[M_1, N_1] z_1;  // standardized group-level effects
  cholesky_factor_corr[M_1] L_1;  // cholesky factor of correlation matrix
  vector<lower=0>[M_2] sd_2;  // group-level standard deviations
  matrix[M_2, N_2] z_2;  // standardized group-level effects
  cholesky_factor_corr[M_2] L_2;  // cholesky factor of correlation matrix
  vector<lower=0>[M_3] sd_3;  // group-level standard deviations
  matrix[M_3, N_3] z_3;  // standardized group-level effects
  cholesky_factor_corr[M_3] L_3;  // cholesky factor of correlation matrix
  vector<lower=0>[M_4] sd_4;  // group-level standard deviations
  matrix[M_4, N_4] z_4;  // standardized group-level effects
  cholesky_factor_corr[M_4] L_4;  // cholesky factor of correlation matrix
}
transformed parameters {
  matrix[N_1, M_1] r_1;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_1] r_1_1;
  vector[N_1] r_1_2;
  vector[N_1] r_1_3;
  vector[N_1] r_1_4;
  matrix[N_2, M_2] r_2;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_2] r_2_1;
  vector[N_2] r_2_2;
  vector[N_2] r_2_3;
  vector[N_2] r_2_4;
  matrix[N_3, M_3] r_3;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_3] r_3_zi_1;
  vector[N_3] r_3_zi_2;
  vector[N_3] r_3_zi_3;
  vector[N_3] r_3_zi_4;
  matrix[N_4, M_4] r_4;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_4] r_4_zi_1;
  vector[N_4] r_4_zi_2;
  vector[N_4] r_4_zi_3;
  vector[N_4] r_4_zi_4;
  real lprior = 0;  // prior contributions to the log posterior
  // compute actual group-level effects
  r_1 = scale_r_cor(z_1, sd_1, L_1);
  r_1_1 = r_1[, 1];
  r_1_2 = r_1[, 2];
  r_1_3 = r_1[, 3];
  r_1_4 = r_1[, 4];
  // compute actual group-level effects
  r_2 = scale_r_cor(z_2, sd_2, L_2);
  r_2_1 = r_2[, 1];
  r_2_2 = r_2[, 2];
  r_2_3 = r_2[, 3];
  r_2_4 = r_2[, 4];
  // compute actual group-level effects
  r_3 = scale_r_cor(z_3, sd_3, L_3);
  r_3_zi_1 = r_3[, 1];
  r_3_zi_2 = r_3[, 2];
  r_3_zi_3 = r_3[, 3];
  r_3_zi_4 = r_3[, 4];
  // compute actual group-level effects
  r_4 = scale_r_cor(z_4, sd_4, L_4);
  r_4_zi_1 = r_4[, 1];
  r_4_zi_2 = r_4[, 2];
  r_4_zi_3 = r_4[, 3];
  r_4_zi_4 = r_4[, 4];
  lprior += normal_lpdf(b | 0, 0.25);
  lprior += normal_lpdf(Intercept | 0, 0.5);
  lprior += gamma_lpdf(shape | 0.1, 0.1);
  lprior += normal_lpdf(b_zi | 0, 0.25);
  lprior += normal_lpdf(Intercept_zi | 0, 0.5);
  lprior += weibull_lpdf(sd_1 | 2, 0.25);
  lprior += lkj_corr_cholesky_lpdf(L_1 | 2);
  lprior += weibull_lpdf(sd_2 | 2, 0.25);
  lprior += lkj_corr_cholesky_lpdf(L_2 | 2);
  lprior += weibull_lpdf(sd_3 | 2, 0.25);
  lprior += lkj_corr_cholesky_lpdf(L_3 | 2);
  lprior += weibull_lpdf(sd_4 | 2, 0.25);
  lprior += lkj_corr_cholesky_lpdf(L_4 | 2);
}
model {
  // likelihood including constants
  if (!prior_only) {
    // initialize linear predictor term
    vector[N] mu = rep_vector(0.0, N);
    // initialize linear predictor term
    vector[N] zi = rep_vector(0.0, N);
    mu += Intercept + Xc * b;
    zi += Intercept_zi + Xc_zi * b_zi;
    for (n in 1:N) {
      // add more terms to the linear predictor
      mu[n] += r_1_1[J_1[n]] * Z_1_1[n] + r_1_2[J_1[n]] * Z_1_2[n] + r_1_3[J_1[n]] * Z_1_3[n] + r_1_4[J_1[n]] * Z_1_4[n] + r_2_1[J_2[n]] * Z_2_1[n] + r_2_2[J_2[n]] * Z_2_2[n] + r_2_3[J_2[n]] * Z_2_3[n] + r_2_4[J_2[n]] * Z_2_4[n];
    }
    for (n in 1:N) {
      // add more terms to the linear predictor
      zi[n] += r_3_zi_1[J_3[n]] * Z_3_zi_1[n] + r_3_zi_2[J_3[n]] * Z_3_zi_2[n] + r_3_zi_3[J_3[n]] * Z_3_zi_3[n] + r_3_zi_4[J_3[n]] * Z_3_zi_4[n] + r_4_zi_1[J_4[n]] * Z_4_zi_1[n] + r_4_zi_2[J_4[n]] * Z_4_zi_2[n] + r_4_zi_3[J_4[n]] * Z_4_zi_3[n] + r_4_zi_4[J_4[n]] * Z_4_zi_4[n];
    }
    for (n in 1:N) {
      target += zero_inflated_neg_binomial_log_logit_lpmf(Y[n] | mu[n], shape, zi[n]);
    }
  }
  // priors including constants
  target += lprior;
  target += std_normal_lpdf(to_vector(z_1));
  target += std_normal_lpdf(to_vector(z_2));
  target += std_normal_lpdf(to_vector(z_3));
  target += std_normal_lpdf(to_vector(z_4));
}
generated quantities {
  // actual population-level intercept
  real b_Intercept = Intercept - dot_product(means_X, b);
  // actual population-level intercept
  real b_zi_Intercept = Intercept_zi - dot_product(means_X_zi, b_zi);
  // compute group-level correlations
  corr_matrix[M_1] Cor_1 = multiply_lower_tri_self_transpose(L_1);
  vector<lower=-1,upper=1>[NC_1] cor_1;
  // compute group-level correlations
  corr_matrix[M_2] Cor_2 = multiply_lower_tri_self_transpose(L_2);
  vector<lower=-1,upper=1>[NC_2] cor_2;
  // compute group-level correlations
  corr_matrix[M_3] Cor_3 = multiply_lower_tri_self_transpose(L_3);
  vector<lower=-1,upper=1>[NC_3] cor_3;
  // compute group-level correlations
  corr_matrix[M_4] Cor_4 = multiply_lower_tri_self_transpose(L_4);
  vector<lower=-1,upper=1>[NC_4] cor_4;
  // extract upper diagonal of correlation matrix
  for (k in 1:M_1) {
    for (j in 1:(k - 1)) {
      cor_1[choose(k - 1, 2) + j] = Cor_1[j, k];
    }
  }
  // extract upper diagonal of correlation matrix
  for (k in 1:M_2) {
    for (j in 1:(k - 1)) {
      cor_2[choose(k - 1, 2) + j] = Cor_2[j, k];
    }
  }
  // extract upper diagonal of correlation matrix
  for (k in 1:M_3) {
    for (j in 1:(k - 1)) {
      cor_3[choose(k - 1, 2) + j] = Cor_3[j, k];
    }
  }
  // extract upper diagonal of correlation matrix
  for (k in 1:M_4) {
    for (j in 1:(k - 1)) {
      cor_4[choose(k - 1, 2) + j] = Cor_4[j, k];
    }
  }
}

What is the correct way to describe this model?
I tried following the example above, and came up with something like:

\begin{aligned}
p(y_i | \mu_i, \phi) &= \begin{cases}
\xi_i + (1-\xi_i)\cdot \text{NegativeBinomial}(0 | \mu_i, \phi) & \text{if } y_i = 0\\
(1-\xi_i)\cdot\text{NegativeBinomial}(y_i | \mu_i, \phi) & \text{if } y_i \neq 0 \\
\end{cases} \\
\log(\mu_i) & = \beta_{0,i} + \beta_A A_i + \sum_{P \in \{R,C,D\}} \beta_P P_i \\
\text{logit}(\xi_i) & = \gamma_{0,i} + \gamma_A A_i + \sum_{P \in \{R,C,D\}} \gamma_P P_i \\
\forall P \in \{0, R, C, D\}: \beta_P & = \beta_{P,pop} + \beta_{P,pupil[i]} + \beta_{P,pupil:subject[i]} \\
\forall P \in \{0, R, C, D\}: \gamma_P & = \gamma_{P,pop} + \gamma_{P,pupil[i]} + \gamma_{P,pupil:subject[i]} \\
\beta_{0,pop} & \sim \text{Normal}(0,0.5) \\
\beta_{A} & \sim \text{Normal}(0,0.25) \\
\forall P \in \{R, C, D\}: \beta_{P,pop} & \sim \text{Normal}(0,0.25) \\
\forall P \in \{R, C, D\}: \beta_{P,pupil[i]} & \sim \text{MVN}(0,\Sigma_{P,pupil}) \\
\forall P \in \{R, C, D\}: \Sigma_{P,pupil} & \sim \text{diag}(\sigma_{pupil})\Omega_{pupil}\text{diag}(\sigma_{pupil}) \\
\sigma_{pupil} &\sim \text{Weibull}(2, 0.25) \\
\Omega_{pupil} &\sim \text{LKJ}(2) \\
\phi &\sim \text{Gamma}(0.1, 0.1)
\end{aligned}


And similar notation for pupil:subject and for the zero-inflation \xi_i (which then uses the \gamma coefficients).
Does this make sense at all? I’m in particular unsure about the MVN part of the group-level coefficients. As I understand brms, group-level betas are modeled as offsets from a general population-level beta, meaning that the expected value of the group-level beta is defined to be 0 and the prior is placed on its standard deviation (at least that is what brms validate_prior tells me).
I’m also a bit unsure how to account for the idea that the population-level intercept and betas should have their own Cholesky factor as well (since the intercept and the betas will also be correlated in the general population). I guess I should report that as well, with a corresponding \Sigma, right?
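
For reference, one way to read the group-level part of the generated Stan code (the `scale_r_cor` function together with `std_normal_lpdf` on `z_1`) is sketched below. The symbols u, z, L and the index j are notation introduced here for illustration, so treat this as a tentative reading rather than an established description. For pupil j, with the four coefficients (intercept, R, C, D) stacked in one vector:

```latex
\begin{aligned}
z_{j} &\sim \text{Normal}(0, I_4) \\
u_{pupil[j]} &= \text{diag}(\sigma_{pupil}) \, L_{pupil} \, z_{j} \\
\Rightarrow \quad u_{pupil[j]} &\sim \text{MVN}(0, \Sigma_{pupil}) \\
\Sigma_{pupil} &= \text{diag}(\sigma_{pupil}) \, \Omega_{pupil} \, \text{diag}(\sigma_{pupil}) \\
\Omega_{pupil} &= L_{pupil} L_{pupil}^{\top}
\end{aligned}
```

Under this reading, \Sigma is a deterministic function of \sigma and \Omega (an equality, not a distribution), and there is one 4×4 \Sigma per grouping factor and per distributional parameter (four in total, matching sd_1/L_1 through sd_4/L_4 in the code), rather than one \Sigma per predictor.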


For my own learning purposes, here is the output from validate_prior:

                prior     class      coef     group resp dpar nlpar lb ub       source
      normal(0, 0.25)         b                                                   user
      normal(0, 0.25)         b         A                                 (vectorized)
      normal(0, 0.25)         b         C                                 (vectorized)
      normal(0, 0.25)         b         D                                 (vectorized)
      normal(0, 0.25)         b         R                                 (vectorized)
      normal(0, 0.25)         b                            zi                     user
      normal(0, 0.25)         b         A                  zi             (vectorized)
      normal(0, 0.25)         b         C                  zi             (vectorized)
      normal(0, 0.25)         b         D                  zi             (vectorized)
      normal(0, 0.25)         b         R                  zi             (vectorized)
       normal(0, 0.5) Intercept                                                   user
       normal(0, 0.5) Intercept                            zi                     user
 lkj_corr_cholesky(2)         L                                                   user
 lkj_corr_cholesky(2)         L                pupil                     (vectorized)
 lkj_corr_cholesky(2)         L           pupil:subj                       (vectorized)
     weibull(2, 0.25)        sd                                      0            user
     weibull(2, 0.25)        sd                            zi        0            user
     weibull(2, 0.25)        sd                pupil                 0    (vectorized)
     weibull(2, 0.25)        sd         C      pupil                 0    (vectorized)
     weibull(2, 0.25)        sd         D      pupil                 0    (vectorized)
     weibull(2, 0.25)        sd Intercept      pupil                 0    (vectorized)
     weibull(2, 0.25)        sd         R      pupil                 0    (vectorized)
     weibull(2, 0.25)        sd                pupil       zi        0    (vectorized)
     weibull(2, 0.25)        sd         C      pupil       zi        0    (vectorized)
     weibull(2, 0.25)        sd         D      pupil       zi        0    (vectorized)
     weibull(2, 0.25)        sd Intercept      pupil       zi        0    (vectorized)
     weibull(2, 0.25)        sd         R      pupil       zi        0    (vectorized)
     weibull(2, 0.25)        sd           pupil:subj                 0            user
     weibull(2, 0.25)        sd         C pupil:subj                 0    (vectorized)
     weibull(2, 0.25)        sd         D pupil:subj                 0    (vectorized)
     weibull(2, 0.25)        sd Intercept pupil:subj                 0    (vectorized)
     weibull(2, 0.25)        sd         R pupil:subj                 0    (vectorized)
     weibull(2, 0.25)        sd           pupil:subj       zi        0            user
     weibull(2, 0.25)        sd         C pupil:subj       zi        0    (vectorized)
     weibull(2, 0.25)        sd         D pupil:subj       zi        0    (vectorized)
     weibull(2, 0.25)        sd Intercept pupil:subj       zi        0    (vectorized)
     weibull(2, 0.25)        sd         R pupil:subj       zi        0    (vectorized)
      gamma(0.1, 0.1)     shape                                      0            user