Parallelizing an approximate hierarchical Gaussian process with reduce_sum()

Hi all,

In short, I’m rather confused about how to implement within-chain parallelization for a somewhat complex model using reduce_sum() with a user-defined partial-sum function. The examples I’ve found online ([1], [2], [3]) seem to be for pretty simple models, and I’m specifically looking to understand 1) how to correctly perform the slicing into chunks and 2) how to correctly pass parameters, data, and non-model arguments such as the grainsize to reduce_sum().

Background

I’ve posted here recently about my progress on fitting an approximate multivariate hierarchical Gaussian process (GP) model with the goal of identifying spatially variable genes (SVGs) in spatial transcriptomics data. The data comes in 2 pieces - an n_g \times n_s gene-by-spot normalized counts matrix, and an n_s \times 2 matrix of spatial locations for the spots. The counts matrix is then converted to a long-format data.frame having n_s \times n_g rows and columns gene, spot, and gene_expression, where gene and spot are ID variables and gene_expression contains the normalized, scaled expression of each gene at each spot. I then approximate the GP with a Hilbert-space basis function matrix \Phi built from the Matérn 3/2 kernel, with the number of basis functions set to k = 20 by default. All this data is passed to Stan, and the model is compiled and fit using cmdstanr via the meanfield variational inference (VI) algorithm, as the data are much too large for MCMC sampling to be truly feasible. All of this is implemented in my bayesVG R package, which is now available on GitHub.

Notation

The full model is as follows, where e_{ig} is the normalized, scaled expression of gene g at spot i, \beta_0 is a global intercept, \tau is the amplitude (marginal SD) of the GP, \Phi is the matrix of basis functions, \boldsymbol{\alpha} contains the coefficients for the basis functions, and \sigma_e is observational noise.

e_{ig} \sim \text{Gaussian} \left(\beta_0 + \tau_{g(i)}^2 \boldsymbol{\phi}(s(i))^\intercal \ \boldsymbol{\alpha}_{g(i)}, \sigma^2_e \right)

Example code

To make things easier I’ve included an example spatial transcriptomics dataset from 10X Genomics in the bayesVG package, which can itself be installed from GitHub using remotes::install_github("jr-leary7/bayesVG"). I have the most recent versions of CmdStan (v2.36.0) and cmdstanr (v0.8.1) installed, and I compiled CmdStan using cpp_options = list("CXXFLAGS += -O3 -march=native -mtune=native").

Packages

library(dplyr)
library(Seurat)
library(bayesVG)
library(cmdstanr)

Data

After loading the mouse brain dataset into memory, I filter out low-quality genes and spots like so:


data("seu_brain")
gene_set_1 <- Matrix::rowSums(seu_brain@assays$Spatial$counts > 0) >= 10L
gene_set_2 <- Matrix::rowMeans(seu_brain@assays$Spatial$counts) >= 0.1
genes_keep <- rownames(seu_brain)[gene_set_1 & gene_set_2]
seu_brain <- subset(seu_brain, features = genes_keep)
spot_set_1 <- seu_brain$nCount_Spatial >= 500L
spot_set_2 <- seu_brain$nFeature_Spatial >= 1000L
spots_keep <- colnames(seu_brain)[spot_set_1 & spot_set_2]
seu_brain <- subset(seu_brain, cells = spots_keep)
seu_brain <- PercentageFeatureSet(seu_brain, 
                                  pattern = "^mt-", 
                                  col.name = "percent_mito")
seu_brain <- subset(seu_brain, subset = percent_mito < 20)

The next step is to normalize and scale the data, and identify a set of 3000 naive highly variable genes (HVGs), which we’ll use as a candidate set of SVGs.

seu_brain <- SCTransform(seu_brain,
                         assay = "Spatial",
                         variable.features.n = 3000L,
                         vst.flavor = "v2",
                         return.only.var.genes = FALSE,
                         seed.use = 312,
                         verbose = FALSE)

I prepare the long-format expression data.frame and the spatial location matrix like so:

spatial_df <- GetTissueCoordinates(seu_brain) %>%
              relocate(cell) %>%
              rename(spot = cell)
spatial_mtx <- scale(as.matrix(select(spatial_df, -spot)))
expr_mtx <- GetAssayData(seu_brain,
                         assay = "SCT",
                         layer = "data")
expr_df <- as.data.frame(expr_mtx[VariableFeatures(seu_brain), ]) %>%
           mutate(gene = rownames(.), .before = 1) %>%
           tidyr::pivot_longer(cols = !gene,
                               names_to = "spot",
                               values_to = "expression") %>%
           mutate(gene = factor(gene, levels = unique(gene)),
                  spot = factor(spot, levels = unique(spot)), 
                  gene_expression = as.numeric(scale(gene_expression))) %>% 
           as.data.frame()

The next step is to create the basis functions using k-means plus the Matérn 3/2 kernel:

M <- nrow(spatial_mtx)
k <- 20
kmeans_centers <- kmeans(spatial_mtx, centers = k, iter.max = 20L)$centers
dists_centers <- as.matrix(dist(kmeans_centers))
lscale <- median(dists_centers[upper.tri(dists_centers)])
phi <- matrix(0, nrow = M, ncol = k)
for (i in seq(k)) {
  d2 <- rowSums((spatial_mtx - matrix(kmeans_centers[i, ], nrow = M, ncol = 2, byrow = TRUE))^2)
  phi[, i] <- bayesVG:::maternKernel(d2, 
                                     length.scale = lscale, 
                                     nu = 1.5)
}
phi <- scale(phi)
attributes(phi)[2:3] <- NULL  # drop the "scaled:center" & "scaled:scale" attributes added by scale()
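For reference, bayesVG:::maternKernel() is just the standard Matérn ν = 3/2 kernel; a minimal stand-in (written under the assumption that it takes squared distances and uses unit marginal variance, which is how it’s called above) would look like:

# minimal Matérn 3/2 stand-in, assuming squared-distance input & unit marginal variance
matern32 <- function(d2, length.scale) {
  d <- sqrt(d2)
  (1 + sqrt(3) * d / length.scale) * exp(-sqrt(3) * d / length.scale)
}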

Finally, we can put all the data into a list to be passed to CmdStan:

data_list <- list(M = M,
                  N = nrow(expr_df),
                  G = length(unique(expr_df$gene)),
                  k = k,
                  spot_id = as.integer(expr_df$spot),
                  gene_id = as.integer(expr_df$gene),
                  phi = phi,
                  y = expr_df$gene_expression)

Modeling

The file approxGP.stan is included in bayesVG and contains the following code:

data {
  int<lower=1> M;  // number of spots
  int<lower=1> N;  // number of gene-spot pairs in long dataframe
  int<lower=1> G;  // number of genes
  int<lower=1> k;  // number of basis functions used to approximate GP
  array[N] int<lower=1, upper=M> spot_id;  // unique ID for each spot
  array[N] int<lower=1, upper=G> gene_id;  // unique ID for each gene
  matrix[M, k] phi;  // matrix of basis functions used to approximate GP
  vector[N] y;  // vector of normalized, scaled gene expression used as response variable
}

parameters {
  real beta0;  // global intercept
  matrix[k, G] alpha_t;  // transposed matrix of gene-specific coefficients for each basis function
  real<lower=0> sigma_y;  // observation noise of response variable
  vector<lower=0>[G] amplitude;  // vector of gene-specific amplitudes of the approximate GP
  real mu_amplitude;  // mean for the amplitude
  real<lower=0> sigma_amplitude;  // SD for the amplitude
  vector[k] mu_alpha;  // vector of means for the basis function coefficients
  vector<lower=0>[k] sigma_alpha;  // vector of SDs for the basis function coefficients
}

model {
  beta0 ~ std_normal();
  mu_alpha ~ normal(0, 2);
  sigma_alpha ~ normal(0, 2);
  sigma_y ~ normal(0, 2);
  mu_amplitude ~ normal(0, 2);
  sigma_amplitude ~ std_normal();
  vector[G] amplitude_sq = square(amplitude);
  matrix[M, G] phi_alpha = phi * alpha_t;
  vector[N] w;
  for (i in 1:N) {
    w[i] = phi_alpha[spot_id[i], gene_id[i]];
  }
  for (i in 1:G) {
    alpha_t[, i] ~ normal(mu_alpha, sigma_alpha);
    amplitude[i] ~ lognormal(mu_amplitude, sigma_amplitude);
  }
  y ~ normal(beta0 + amplitude_sq[gene_id] .* w, sigma_y);
}

Now we can compile the model, after which we perform an MLE optimization step which will be used as an initialization for the full VI model.

stan_file <- system.file("approxGP.stan", package = "bayesVG")
mod <- cmdstan_model(stan_file, compile = FALSE)
mod$compile(pedantic = TRUE, 
            stanc_options = list("O1"), 
            force_recompile = TRUE, 
            threads = 4L)
fit_optim <- mod$optimize(data_list, 
                          seed = 312, 
                          init = 0, 
                          jacobian = FALSE, 
                          iter = 3000L)

The VI fitting step is then performed:

fit_vi <- mod$variational(data_list,
                          seed = 312,
                          init = fit_optim, 
                          algorithm = "meanfield",
                          iter = 4000L,
                          draws = 1000L)

The final (I promise) step is to extract posterior draws for the estimated gene-specific amplitude parameter \hat{\tau}_g, which is what we use to rank genes by their spatial variability:

gene_mapping <- data.frame(gene = as.character(expr_df$gene),
                           gene_id = as.character(as.integer(expr_df$gene))) %>%
                distinct()
amplitude_summary <- fit_vi$summary(variables = "amplitude") %>%
                     rename_with(~paste0("amplitude_", .), .cols = -1) %>%
                     rename(amplitude_ci_ll = amplitude_q5,
                            amplitude_ci_ul = amplitude_q95) %>%
                     mutate(gene_id = sub("^.*\\[(.*)\\].*$", "\\1", variable), .before = 1) %>%
                     inner_join(gene_mapping, by = "gene_id") %>%
                     relocate(gene) %>%
                     select(-c(variable, gene_id)) %>%
                     mutate(amplitude_dispersion = amplitude_sd^2 / amplitude_mean) %>%
                     arrange(desc(amplitude_mean)) %>%
                     mutate(amplitude_mean_rank = row_number()) %>%
                     as.data.frame() %>% 
                     magrittr::set_rownames(.$gene)

Questions

All of this works pretty well - the estimated SVG set matches the underlying biology well, and the top SVGs exhibit clear spatial dependency when visualized. However, the fitting process takes a fair amount of time even on a small dataset of, say, 3000 genes across 2500 spots - usually around 45 minutes when performing both the MLE optimization and VI steps. I’d like to bring that runtime down by parallelizing with reduce_sum() and a partial-sum function if possible, as runtime is a big concern for users with much larger spatial datasets.

I know the first step is to write a function in the functions block that computes the partial likelihood given a slice of the data, but I’m unsure how the slicing should be performed when I also need to index by spot_id and gene_id. I’m also unsure how to correctly pack all the parameters and data into the arguments provided to reduce_sum(); from the docs it appears that you can pass as many arguments as you need, with data arguments preceded by the data prefix, but in practice this has thrown errors related to the number of arguments. I’m also confused by the int dummy argument I’ve seen passed to reduce_sum() in several of the examples I’ve read through online (one example) - what does this do and why is it necessary? I can attach the code I’ve tried if needed, but it doesn’t currently compile so I’m not sure that it’s useful, and this post is probably already too long.
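To make the question concrete, here’s the rough structure I have in mind based on the reduce_sum chapter of the user’s guide - slicing over the rows of the long dataframe and recovering each row’s spot and gene indices from start and end. This is only a sketch (the function name, the argument packing, and the grainsize of 1 are guesses on my part), not the non-compiling code mentioned above:

functions {
  // partial log-likelihood over rows start:end of the long dataframe
  real partial_sum(array[] real y_slice,
                   int start, int end,
                   array[] int spot_id,
                   array[] int gene_id,
                   matrix phi_alpha,
                   vector amplitude_sq,
                   real beta0,
                   real sigma_y) {
    int n = end - start + 1;
    vector[n] mu;
    for (i in 1:n) {
      int idx = start + i - 1;  // index into the full-length data arrays
      mu[i] = beta0 + amplitude_sq[gene_id[idx]] * phi_alpha[spot_id[idx], gene_id[idx]];
    }
    return normal_lpdf(to_vector(y_slice) | mu, sigma_y);
  }
}

with the vectorized likelihood line in the model block replaced by something like:

// phi_alpha & amplitude_sq are computed in the model block as before;
// a grainsize of 1 lets the scheduler choose the chunk sizes
target += reduce_sum(partial_sum, to_array_1d(y), 1,
                     spot_id, gene_id, phi_alpha, amplitude_sq, beta0, sigma_y);

(Declaring y as array[N] real in the data block would avoid the to_array_1d() conversion on every evaluation.)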

Any guidance or ideas are greatly appreciated!

-jack

The only way you’re going to get a speedup with parallelization is if the compute done in each shard justifies the communication costs. With a GP, that should be the case if there are multiple solves being done in the shards. But then it looks like you’ve finessed the GP to the point where I don’t see any kind of solves being required.

Have you checked that this improves things? Initializing at a mode can often lead to seriously problematic behavior because, as the highest-density point, it’s hard to move away from. I don’t know how big your k and G are, but if k * G < 200 or so, you should be able to do a reasonable Laplace approximation if your density has a mode. That might outperform ADVI for log concave targets.

I’m not exactly sure what you’re trying to do, but you can’t use data as a prefix in a function call.

Have you tried pathfinder? I’m just curious how it’d work. Our implementation of ADVI is known to be unstable in terms of step size selection and using too few draws for the nested ELBO eval, at least under its default arguments. Have you done things like posterior predictive checks to make sure the fits are doing what you want?

Also, you might want to evaluate very short runs of MCMC initialized from a draw from the variational fit. This can often clean up the VI fit considerably at relatively low cost.

Thanks for the input! In my case k \times G is usually around 60,000, since the number of genes G is usually set to 3,000. Thus it seems the Laplace approximation is out.

I have tried Pathfinder as well – it actually only ran without error when I initialized from the mode; otherwise all 4 paths would error out. Is that strange behavior, or would it make sense given the relative sparsity and size of the data? Once I got it to run via the mode initialization it worked well and provided visually correct results, but it took about twice as long as the meanfield algorithm given the same resources.

Cleaning up the VI fit via MCMC sounds reasonable – I’ll give that a try.

Thanks again,
jack

Yes. If you have a reproducible example you could share, that’d be helpful.

It’s hard for me to imagine a situation where Pathfinder took twice as long as mean field. Did you also initialize mean field at the mode?

yes, i initialized both the meanfield and Pathfinder algorithms at the mode.

here’s a reproducible example using the mouse brain dataset that i’ve included in bayesVG. installing bayesVG from GitHub should install all the necessary dependencies, and I’m assuming you already have CmdStan and cmdstanr installed & set up.

# install packages 
remotes::install_github("jr-leary7/bayesVG")

# libraries
library(Seurat)
library(bayesVG)

# load & normalize data -- test only 500 most variable genes for speed's sake 
data("seu_brain", package = "bayesVG")
seu_brain <- SCTransform(seu_brain, 
                         assay = "Spatial", 
                         variable.features.n = 500L, 
                         vst.flavor = "v2", 
                         seed.use = 312, 
                         verbose = FALSE)
seu_brain_mf <- seu_brain
seu_brain_pf <- seu_brain
rm(seu_brain)

# run bayesVG model -- meanfield algorithm 
mf_start <- Sys.time()
seu_brain_mf <- findSpatiallyVariableFeaturesBayes(seu_brain_mf, 
                                                   naive.hvgs = VariableFeatures(seu_brain_mf),
                                                   algorithm = "meanfield", 
                                                   mle.init = TRUE,  # initialize at the mode 
                                                   kernel = "matern", 
                                                   kernel.smoothness = 1.5, 
                                                   n.cores = 2L,  # used for compilation only 
                                                   save.model = TRUE)
mf_end <- Sys.time()
mf_diff <- mf_end - mf_start
print(mf_diff)

# bayesVG model -- pathfinder algorithm 
pf_start <- Sys.time()
seu_brain_pf <- findSpatiallyVariableFeaturesBayes(seu_brain_pf, 
                                                   naive.hvgs = VariableFeatures(seu_brain_pf),
                                                   algorithm = "pathfinder", 
                                                   mle.init = TRUE,  # initialize at the mode 
                                                   kernel = "matern", 
                                                   kernel.smoothness = 1.5, 
                                                   n.cores = 4L,  # used for compilation & pathfinder  
                                                   save.model = TRUE)
pf_end <- Sys.time()
pf_diff <- pf_end - pf_start
print(pf_diff)

in this example (i think) it would be really odd if Pathfinder didn’t outperform meanfield, since it uses 4 threads to compute the 4 paths in parallel whereas the meanfield version of the function is single-threaded. however, in practice i’ve seen that Pathfinder takes a fair bit longer even with this advantage.

also, the save.model argument of findSpatiallyVariableFeaturesBayes() saves the fitted model from cmdstanr to the object’s unstructured metadata, from which it can be retrieved like so:

mod_mf <- extractModel(seu_brain_mf)
mod_pf <- extractModel(seu_brain_pf)

this allows the inspection of the final fit, generation of custom posterior draws, etc.
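for example, assuming extractModel() hands back the underlying cmdstanr fit object, pulling raw posterior draws of the amplitudes looks something like this:

# draws() is the standard cmdstanr / posterior accessor on the fit object
amplitude_draws <- mod_mf$draws(variables = "amplitude", format = "draws_df")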

I probably have some old versions lying around. I exclusively work in Python these days. From your config, are you using single-path or multi-path Pathfinder? Unless the log density calculation is parallelized, using multiple cores might not add anything.

Both of these algorithms involve touchy optimization steps. What you’re calling “meanfield” is the autodiff variational inference (ADVI) algorithm with a diagonal covariance approximation. Its stepsize adaptation is not quite right and it can get stuck during optimization. With ADVI, we estimate the KL divergence through the ELBO using a small number of draws from the approximating normal distribution; the estimate has much lower variance if you use more draws, but that gets very expensive without a GPU, and we don’t have Stan set up to use GPUs that efficiently (the evaluations are through JAX implementations).

With Pathfinder, we use L-BFGS optimization. That uses a local estimate of curvature to help guide the optimizer and also to approximate the covariance. This can be inaccurate in complicated posterior geometries and lead to the optimizer getting stuck. The paper talks about the case of very flat (poorly identified) target densities being a particular issue.

Initializing Pathfinder at the mode is problematic. It needs to take several steps to estimate covariance, and it will get stuck at the mode. Also, there are a lot of problems that don’t have a mode.

There are also tuning parameters for capping the number of optimization steps in Pathfinder and the number of steps overall in ADVI. These can have a big effect on performance.

As far as performance with multiple threads or processes goes, it really depends on a few things. A big one is your memory architecture, which will determine how much CPU you can actually use across 4 threads versus how much time the CPU will be busy-waiting for memory. Also, running four paths of multi-path Pathfinder might not be worth it versus 1 (or versus 20) on an accuracy-per-compute basis. It will depend on the problem.

I really appreciate the detailed response! Here’s what I use to install and set up CmdStan:

install.packages("cmdstanr", repos = c("https://stan-dev.r-universe.dev", getOption("repos")))
library(cmdstanr)
install_cmdstan(cores = 4L,
                overwrite = TRUE,
                cpp_options = list("CXXFLAGS += -O3 -march=native -mtune=native"))

From your response it seems like I definitely shouldn’t be using the mode initialization for Pathfinder, but it’s unfortunately the only way I’ve gotten the algorithm to run without all 4 paths immediately erroring out. I can dig up the exact error message and attach it, but it’ll take me a few minutes to write up the example.

I can increase the number of Monte Carlo samples used to estimate the ELBO in the meanfield ADVI algorithm via the elbo_samples argument – it looks like the default is 100 samples (source). Would increasing this to, say, 150 help much, or is that too small a difference to be worth it?
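i.e., something like this (picking 500 draws purely as an illustration):

fit_mf <- mod$variational(data_list,
                          seed = 312,
                          algorithm = "meanfield",
                          elbo_samples = 500L,  # default is 100
                          iter = 4000L,
                          draws = 1000L)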

As for my setup, I’m using an M2 Mac Mini with 8 cores. I thought as per the documentation that when running multi-path Pathfinder (with 4 paths as the default) the num_threads parameter would parallelize across paths:

num_threads (positive integer) If the model was compiled with threading support, the number of threads to use in parallelized sections (e.g., for multi-path pathfinder as well as reduce_sum).

I also have access to a compute cluster with pretty good CPUs and Nvidia GPUs. I assume using that (especially for larger datasets) and compiling with OpenCL support would increase performance? I can’t compile with OpenCL on my M2 Mac since the GPU doesn’t support double-precision floating point for some reason.

I was just talking to @stevebronder about this and he says he uses initialization at the mode to test error handling for Pathfinder. It’s definitely not going to work.

The 100 is much higher than it used to be. At that point, there’s a question as to whether ADVI is going to be faster than MCMC. Stabilizing it fully requires on the order of 10K evaluations, which is definitely prohibitive without a GPU, and we don’t provide a GPU implementation. The standard error of the estimator goes down as 1 / sqrt(N).

Yes, it should. I was really just asking if you were running single or multi-path. But if you’re initializing at the mode, it’s not going to matter.

If Pathfinder errors out, can you try running optimization? About the only place it should error out is during optimization.

Here’s an example of how I would manually set up and fit the models (i.e., not using the bayesVG::findSpatiallyVariableFeaturesBayes() function). The optimization and meanfield algorithms work just fine, but Pathfinder errors out in ~30s.

##### load necessary packages ######
library(dplyr)
library(Seurat)
library(bayesVG)
library(cmdstanr)

##### load data, normalize, & identify naive HVGs #####
data("seu_brain", package = "bayesVG")
seu_brain <- SCTransform(seu_brain, 
                         assay = "Spatial", 
                         variable.features.n = 3000L, 
                         return.only.var.genes = FALSE, 
                         seed.use = 312, 
                         verbose = FALSE)

##### extract spatial & expression matrices #####
spatial_df <- select(GetTissueCoordinates(seu_brain), -cell)
spatial_mtx <- scale(as.matrix(spatial_df))
expr_mtx <- GetAssayData(seu_brain,
                         assay = "SCT",
                         layer = "data")
expr_df <- as.data.frame(expr_mtx[VariableFeatures(seu_brain), ]) %>%
           mutate(gene = rownames(.), .before = 1) %>%
           tidyr::pivot_longer(cols = !gene,
                               names_to = "spot",
                               values_to = "expression") %>%
           mutate(gene = factor(gene, levels = unique(gene)),
                  spot = factor(spot, levels = unique(spot)), 
                  expression = as.numeric(scale(expression))) %>% 
           as.data.frame()

##### compute basis function matrix #####
M <- nrow(spatial_mtx)
k <- 20
kmeans_centers <- kmeans(spatial_mtx, centers = k, iter.max = 20L)$centers
dists_centers <- as.matrix(dist(kmeans_centers))
lscale <- median(dists_centers[upper.tri(dists_centers)])
phi <- matrix(0, nrow = M, ncol = k)
for (i in seq(k)) {
  d2 <- rowSums((spatial_mtx - matrix(kmeans_centers[i, ], nrow = M, ncol = 2, byrow = TRUE))^2)
  phi[, i] <- bayesVG:::maternKernel(d2, length.scale = lscale, nu = 1.5)
}
phi <- scale(phi)
attributes(phi)[2:3] <- NULL

##### prepare data & compile model #####
data_list <- list(M = M,
                  N = nrow(expr_df),
                  G = length(unique(expr_df$gene)),
                  k = k,
                  spot_id = as.integer(expr_df$spot),
                  gene_id = as.integer(expr_df$gene),
                  phi = phi,
                  y = expr_df$expression)
stan_file <- system.file("approxGP.stan", package = "bayesVG")
mod <- cmdstan_model(stan_file, compile = FALSE)
mod$compile(pedantic = TRUE, 
            stanc_options = list("O1"), 
            force_recompile = TRUE, 
            threads = 4L)

##### fit models #####
fit_optim <- mod$optimize(data_list, 
                          seed = 312, 
                          init = 0, 
                          jacobian = FALSE, 
                          iter = 3000L)
fit_mf <- mod$variational(data_list,
                          seed = 312,
                          algorithm = "meanfield",
                          iter = 4000L,
                          draws = 1000L)
fit_pf <- mod$pathfinder(data_list, 
                         seed = 312, 
                         num_threads = 4L, 
                         draws = 1000L, 
                         num_paths = 4L)

The error I get for each path (showing only Path 1 for brevity’s sake) when running Pathfinder is below:

Path [1] :Initial log joint density = -86127675758.523758 
Path [1] : Iter      log prob        ||dx||      ||grad||     alpha      alpha0      # evals       ELBO    Best ELBO        Notes  
              1      -8.613e+10      0.000e+00   1.737e+11    1.000e-03  1.000e-03         1        nan       -inf                   
Error evaluating model log probability: Non-finite gradient. 
Error evaluating model log probability: Non-finite gradient. 
Error evaluating model log probability: Non-finite gradient. 
Error evaluating model log probability: Non-finite gradient. 
Error evaluating model log probability: Non-finite gradient. 
Error evaluating model log probability: Non-finite gradient. 
Error evaluating model log probability: Non-finite gradient. 
Error evaluating model log probability: Non-finite gradient. 
Error evaluating model log probability: Non-finite gradient. 
Error evaluating model log probability: Non-finite gradient. 
Error evaluating model log probability: Non-finite gradient. 
Chain 1 Optimization terminated with error: Line search failed to achieve a sufficient decrease, no more progress can be made Optimization failed to start, pathfinder cannot be run.
Chain 1 Pathfinder iteration: 0 failed.

Thanks. This is immediately getting into a bad place in the parameter space. Our random initialization assumes the posterior is roughly standard normal, so initializes on that scale. If it’s shifted from being centered at the origin or badly scaled, it can be very hard to initialize.

There are several things you can do to make this behave better. One is to transform parameters so that the posterior is closer to standard normal. For instance, if you have a scalar parameter alpha that takes values around 1000 with a standard deviation of 200 in the posterior, you can declare it as

real<offset=1000, multiplier=200> alpha;

This makes the unconstrained value (alpha - 1000) / 200, which should be more unit scaled.

The other thing you can do is reduce the variance of the init. If you give it an init argument of 0.5, then it will initialize uniform(-0.5, 0.5) rather than the default of uniform(-2, 2). This number can even be zero, in which case every parameter will be initialized to zero (on the unconstrained scale).
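With cmdstanr that’s just the init argument to the fitting method, e.g. something like:

# shrink the random init range to uniform(-0.5, 0.5) on the unconstrained scale
fit_pf <- mod$pathfinder(data_list, seed = 312, init = 0.5, num_paths = 4L, draws = 1000L)

# or start every parameter at zero on the unconstrained scale
fit_pf <- mod$pathfinder(data_list, seed = 312, init = 0, num_paths = 4L, draws = 1000L)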

apologies for the late reply – that all makes sense. after some further tinkering i found that initializing the pathfinder algorithm at 0 instead of the default \boldsymbol{\theta} \sim \text{Uniform}(-2, 2) initialization led to much better performance. thank you again for all your help!