Performing posterior predictive checks for variational inference models

Hi all,

Apologies in advance for the long post - feel free to jump to the Questions section at the end.

Background


Over the past year or so I’ve been working on developing a hierarchical approximate Gaussian process (GP) model for spatial transcriptomics datasets. The data are composed of an n_s \times n_g counts matrix that has been normalized and scaled, and an n_s \times 2 coordinates matrix that has also been scaled. The subscripts s and g correspond to “spots” (small groups of individual cells) and genes, respectively. The goal of the model is to flexibly & accurately identify genes whose expression exhibits strong spatial dependence, i.e., genes whose expression forms distinct patterns when visualized on the coordinates. Such genes are called spatially variable genes (SVGs).

Methodology


The math bit

The scaled, normalized expression of gene g at spot i follows a Gaussian likelihood defined like so:

\tilde{e}_{ig} \sim \text{Gaussian}(\mu_{ig},\: \sigma^2_{\tilde{e}})

where the mean is defined as:

\mu_{ig} = \beta_{0_g} + \tau_g^2 \tilde{\phi}(s_i)^\intercal \alpha_g

In the above equation \beta_{0_g} represents the gene-specific intercept, \tau_g is the gene-specific amplitude (or marginal SD) of the GP, \tilde{\phi}(s_i) is the i^{\text{th}} row of the n_s \times k matrix of orthonormal basis functions used to capture the underlying spatial structure, and \alpha_g is the corresponding vector of gene-specific basis function coefficients.

In addition, by default we estimate the j = 1, \dots, k basis functions using the exponentiated quadratic kernel:

\phi_j(s_i) = \exp\left( - \frac{||(x_i, y_i) - c_j||^2}{2\hat{\ell}^2}\right)

where c_j is the j^{\text{th}} centroid obtained by running k-means clustering on the scaled coordinates. The estimated global length-scale \hat{\ell} of the GP is taken to be the median of the Euclidean distances between the k cluster centroids.
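For concreteness, here is a minimal R sketch of that construction; the toy data and variable names are illustrative, not the actual bayesVG internals:

```r
# Minimal sketch of the basis construction described above.
# `coords` is a toy stand-in for the scaled n_s x 2 coordinates matrix.
set.seed(312)
coords <- matrix(runif(400), ncol = 2)  # 200 "spots" in 2D
k <- 10L

# centroids from k-means on the scaled coordinates
centroids <- kmeans(coords, centers = k)$centers

# global length-scale: median pairwise Euclidean distance between centroids
ell_hat <- median(dist(centroids))

# n_s x k exponentiated quadratic basis functions evaluated at each spot
all_d2 <- as.matrix(dist(rbind(coords, centroids)))^2
d2 <- all_d2[seq_len(nrow(coords)), nrow(coords) + seq_len(k)]
phi_raw <- exp(-d2 / (2 * ell_hat^2))

# orthonormalize the columns via QR, matching the "QR-decomposed"
# phi passed to the Stan data block
phi <- qr.Q(qr(phi_raw))
```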

The code bit

All of this is implemented in Stan (with non-centered priors for the gene-specific intercepts and basis function coefficients).

data {
  int<lower=1> M;  // number of spots
  int<lower=1> N;  // number of gene-spot pairs in long dataframe
  int<lower=1> G;  // number of genes
  int<lower=1> k;  // number of basis functions used to approximate GP
  array[N] int<lower=1, upper=M> spot_id;  // unique ID for each spot
  array[N] int<lower=1, upper=G> gene_id;  // unique ID for each gene
  matrix[M, k] phi;  // matrix of QR-decomposed basis functions used to approximate GP
  vector[N] y;  // vector of normalized, scaled gene expression used as response variable
}

parameters {
  real mu_beta0;  // mean for the global intercepts
  real<lower=0> sigma_beta0;  // SD for the global intercepts 
  vector[G] z_beta0;  // vector of standard normal RVs for the global intercepts 
  real mu_amplitude;  // mean for the amplitude
  real<lower=0> sigma_amplitude;  // SD for the amplitude
  vector<lower=0>[G] amplitude;  // vector of gene-specific amplitudes of the approximate GP
  vector[k] mu_alpha;  // vector of means for the basis function coefficients
  vector<lower=0>[k] sigma_alpha;  // vector of SDs for the basis function coefficients
  matrix[k, G] z_alpha_t;  // standard normal RV for basis function coefficients
  real<lower=0> sigma_y;  // observation noise of response variable
}

transformed parameters {
  vector[G] beta0 = mu_beta0 + sigma_beta0 * z_beta0;
  vector[G] amplitude_sq = square(amplitude);
}

model {
  matrix[k, G] alpha_t;
  alpha_t = rep_matrix(mu_alpha, G) + diag_pre_multiply(sigma_alpha, z_alpha_t);
  matrix[M, G] phi_alpha = phi * alpha_t;
  vector[N] w;
  for (i in 1:N) {
    w[i] = phi_alpha[spot_id[i], gene_id[i]];
  }
  to_vector(z_alpha_t) ~ std_normal();
  z_beta0 ~ std_normal();
  mu_beta0 ~ normal(0, 2);
  sigma_beta0 ~ std_normal();
  mu_amplitude ~ normal(0, 2);
  sigma_amplitude ~ std_normal();
  amplitude ~ lognormal(mu_amplitude, sigma_amplitude);
  mu_alpha ~ normal(0, 2);
  sigma_alpha ~ std_normal();
  sigma_y ~ std_normal();
  y ~ normal(beta0[gene_id] + amplitude_sq[gene_id] .* w, sigma_y);
}

After formatting the spatial transcriptomics data such that it matches the structure of the model, I use cmdstanr to perform the fitting. First, I run a quick optimization to initialize all of the model parameters near their modes, as I’ve seen empirically that this increases convergence speed for the next step - fitting the model using variational inference (VI). Here’s what the full process looks like:

model_init <- mod$optimize(data_list,
                           seed = random.seed,  # defaults to 312 
                           init = 0,
                           opencl_ids = opencl_IDs,  # defaults to NULL
                           jacobian = FALSE,
                           iter = mle.iter,  # defaults to 1000
                           algorithm = "lbfgs",
                           history_size = 25L)
fit_vi <- mod$variational(data_list,
                          seed = random.seed,  # defaults to 312 
                          init = model_init,
                          algorithm = algorithm,  # defaults to "meanfield"
                          iter =  n.iter,  # defaults to 3000
                          draws = n.draws,  # defaults to 1000
                          opencl_ids = opencl_IDs,  # defaults to NULL
                          elbo_samples = elbo.samples)  # defaults to 150

Finally, I use the posterior package to generate (by default) 1000 draws from the variational posterior of \hat{\tau}_g for every gene g. Taking the mean of the draws gives a point estimate of spatial variability for each gene, which can then be used to rank the genes, with e.g. the top 1000 genes by estimated amplitude designated as SVGs.
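In code, that post-processing step looks roughly like the following; the draws matrix here is a simulated stand-in for fit_vi$draws("amplitude"), and gene_names and the cutoff are illustrative:

```r
library(posterior)

# Simulated stand-in for fit_vi$draws("amplitude"): 1000 draws x 5 genes
set.seed(312)
amp_draws <- as_draws_matrix(
  matrix(abs(rnorm(5000)), nrow = 1000,
         dimnames = list(NULL, paste0("amplitude[", 1:5, "]")))
)

# posterior mean amplitude per gene = point estimate of spatial variability
amp_means <- colMeans(amp_draws)

# rank genes by estimated amplitude; the top-ranked genes are the SVGs
gene_names <- paste0("gene_", 1:5)  # illustrative names
svgs <- gene_names[order(amp_means, decreasing = TRUE)][1:2]
```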

All of this is implemented in a much more user-friendly manner in my R package bayesVG.

Challenges


  • High dimensionality - typically the number of genes n_g \in [2000, 4000]
  • Scalability - depending on the technology used to sequence the raw tissue, the number of observations n_s can range from the thousands (e.g., 10X Visium) to the hundreds of thousands (e.g., 10X Xenium)
  • Interpretability - whatever parameter estimate used to rank & select SVGs should have an interpretation grounded in the underlying biology
  • Accuracy - the variational posterior might not be a good enough approximation to the true posterior

The last point (accuracy) is the one I’m still working on and am currently concerned about.

Questions


I’ve read a decent amount about how to perform posterior predictive checks (PPCs) for models fit using classical MCMC sampling - for example, the official Stan documentation and the documentation for the loo package used to perform leave-one-out cross-validation. However, neither resource has examples or information related to models fit using VI.

I also read this great preprint on evaluating the quality of the variational approximation, but I’m not quite sure how to take the concepts there and turn them into working R + Stan code. I know I need to add a generated quantities block to the Stan code, I just don’t know what I need to put there or how to process it after fitting.

Any hints on how to approach this would be greatly appreciated!

-Jack


The first step would be to compare posterior predictive replicates to the data. Assuming you don’t want the variables you defined in the model block in the output, add the following to the Stan code:

generated quantities {
  vector[N] y_rep;
  {
    matrix[k, G] alpha_t;
    alpha_t = rep_matrix(mu_alpha, G) + diag_pre_multiply(sigma_alpha, z_alpha_t);
    matrix[M, G] phi_alpha = phi * alpha_t;
    vector[N] w;
    for (i in 1:N) {
      w[i] = phi_alpha[spot_id[i], gene_id[i]];
    }
    y_rep = normal_rng(beta0[gene_id] + amplitude_sq[gene_id] .* w, sigma_y);
  }
}

Then you can use various bayesplot package visualizations (see also Visualization in Bayesian workflow). These work just fine with any approximate inference.
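For example, with stand-ins for the data vector y and the S x N replicate matrix y_rep (here simulated so the snippet runs on its own):

```r
library(bayesplot)

# toy stand-ins: in practice y is the observed expression vector and
# y_rep the S x N draws matrix of posterior predictive replicates
set.seed(312)
y <- rnorm(500)
y_rep <- matrix(rnorm(100 * 500), nrow = 100)

ppc_dens_overlay(y, y_rep[1:50, ])  # overlay replicate densities on the data
ppc_stat(y, y_rep, stat = "sd")     # distribution of a scalar test statistic
```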

Unfortunately, the ADVI algorithm in Stan is severely outdated, and it is unlikely that faster and more stable algorithms will be available in Stan soon. You might consider using BridgeStan and some external algorithm.

You can also try whether Pathfinder (Pathfinder: Parallel quasi-Newton variational inference) would be sufficiently good. It would be much faster than ADVI, may be accurate enough, and is available in Stan (use $pathfinder(..., num_paths=10, single_path_draws=40, draws=400, history_size=100, max_lbfgs_iters=100))

Fast leave-one-out cross-validation for ADVI is discussed in Bayesian leave-one-out cross-validation for large data and Leave-One-Out Cross-Validation for Bayesian Model Comparison in Large Data, but due to the high dimensionality of the posterior it’s likely to fail. It’s better to use K-fold CV in this case.

I could provide the code, but again, due to the high dimensionality, the approach presented in that paper is unlikely to be useful.

Instead I recommend looking at [2502.03279] Posterior SBC: Simulation-Based Calibration Checking Conditional on Data, which would work, but likely requires running the inference 100 times.

Thank you for the detailed info! Apologies for the late response - I was wrapping up the semester & traveling home for the holidays.

I implemented your suggested syntax for Stan’s generated quantities block like so (it’s a bit different because I had to fix some compilation errors):

generated quantities {
  vector[N] y_rep;
  {
    matrix[k, G] alpha_t;
    alpha_t = rep_matrix(mu_alpha, G) + diag_pre_multiply(sigma_alpha, z_alpha_t);
    matrix[M, G] phi_alpha = phi * alpha_t;
    for (i in 1:N) {
      real w_i = phi_alpha[spot_id[i], gene_id[i]];
      y_rep[i] = normal_rng(beta0[gene_id[i]] + amplitude_sq[gene_id[i]] * w_i, sigma_y);
    }
  }
}

Then in R (post-fitting of the VI model, denoted fit_vi):

fit_gq <- mod$generate_quantities(fit_vi,
                                  data = data_list,
                                  seed = 312)
y_rep <- fit_gq$draws("y_rep", format = "draws_matrix")

I’ll note first that while the VI step completed pretty quickly for n_g = 1000 genes and n_s = 2444 spots (so overall n = 2444000), the GQ step has been hanging, and my RAM usage according to RStudio has been between 10 and 12 GB. I assume this is due to the value of n being greater than 2 million. Finally, even if the GQ step were to finish successfully, I seriously doubt that I could load an n_{\text{draws}} \times n matrix with 2.444 million columns into memory.

Is there any way for me to aggregate / compute PPC statistics in Stan prior to loading the results into memory in R? For context, I’m (by far) most interested in performing PPCs for the approximate posterior of the per-gene amplitude parameter \hat{\tau}_g; at this point I’m much less concerned with the posteriors for the basis function coefficients or the intercepts.

Thanks,
Jack

Right, I forgot your N is so big.

Yes, you can compute the test statistics for each posterior draw in generated quantities, and then instead of SxN (where S is the number of posterior draws) you would return SxT (where T is the number of test statistics). We have recently recommended visualization and probability integral transformation (PIT) checks, but they would require SxN memory. The downside of scalar test statistics is that you need to be careful to select one that is as ancillary as possible: a statistic that is not one-to-one mapped to one of the model parameters, tells you something useful about the data distribution, and is fast to compute for N replicates.
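As an illustrative sketch, a generated quantities block that returns per-gene summaries of the replicates - 2G values per draw instead of N - could look like this. The choice of statistics here (per-gene mean and SD of the replicated expression) is just an example; per the caveat above, the SD is the more interesting of the two, since the per-gene mean is closely tied to beta0:

```stan
generated quantities {
  vector[G] y_rep_mean;  // per-gene mean of replicated expression
  vector[G] y_rep_sd;    // per-gene SD of replicated expression
  {
    matrix[k, G] alpha_t = rep_matrix(mu_alpha, G)
                           + diag_pre_multiply(sigma_alpha, z_alpha_t);
    matrix[M, G] phi_alpha = phi * alpha_t;
    vector[G] sum_y = rep_vector(0, G);
    vector[G] sum_y2 = rep_vector(0, G);
    array[G] int n_per_gene = rep_array(0, G);
    for (i in 1:N) {
      real w_i = phi_alpha[spot_id[i], gene_id[i]];
      real yr = normal_rng(beta0[gene_id[i]] + amplitude_sq[gene_id[i]] * w_i,
                           sigma_y);
      sum_y[gene_id[i]] += yr;
      sum_y2[gene_id[i]] += square(yr);
      n_per_gene[gene_id[i]] += 1;
    }
    for (g in 1:G) {
      y_rep_mean[g] = sum_y[g] / n_per_gene[g];
      y_rep_sd[g] = sqrt(sum_y2[g] / n_per_gene[g] - square(y_rep_mean[g]));
    }
  }
}
```

After running generate_quantities with a block like this, the returned draws are only S x 2G, which should fit comfortably in memory even for thousands of genes.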

PPC refers to posterior predictive checking, and is used to examine how well the posterior predictive distribution matches the data. It’s not directly assessing how good the approximate posterior is compared to the true posterior.

It seems using posterior SBC with VI, or comparing to MCMC results, would be the best choices for you, even though they are computationally costly.

Ah OK, thank you for correcting my interpretation of what a PPC is. I’ll do some thinking about what scalar statistic(s) might be worth implementing as well as looking more into posterior SBC (I found the preprint you put out on it earlier this year & am reading through it now).