Hi all,
In short, I’m rather confused about how to implement within-chain parallelization for a somewhat complex model using reduce_sum() and a user-defined partial_sum() function. The examples I’ve found online ([1], [2], [3]) seem to be for pretty simple models, and I’m specifically looking to understand 1) how to correctly perform the slicing into chunks and 2) how to correctly pass parameters, data, and non-model arguments such as the grainsize to reduce_sum().
Background
I’ve posted here recently about my progress on fitting an approximate multivariate hierarchical Gaussian process (GP) model with the goal of identifying spatially variable genes (SVGs) in spatial transcriptomics data. The data come in two pieces: an n_{g} \times n_s gene-by-spot normalized counts matrix, and an n_s \times 2 matrix of spatial locations for each spot. The counts matrix is converted to a long-format data.frame with n_s \times n_g rows and columns gene, spot, and gene_expression, where gene and spot are ID variables and gene_expression contains the normalized, scaled expression of each gene at each spot. I then approximate the GP in a Hilbert space using a basis function matrix \Phi built from the Matérn 3/2 kernel, with the number of basis functions set to k = 20 by default. All of this data is passed to Stan, and the model is compiled and fit using cmdstanr via the meanfield variational inference (VI) algorithm, as the data are much too large for MCMC sampling to be truly feasible. All of this is implemented in my bayesVG R package, which is now available on GitHub.
Notation
The full model is as follows, where e_{ig} is the normalized, scaled expression of gene g at spot i, \beta_0 is a global intercept, \tau_g is the gene-specific amplitude (marginal SD) of the GP, \Phi is the matrix of basis functions, \boldsymbol{\alpha}_g contains the gene-specific coefficients for the basis functions, and \sigma_e is the observational noise:

e_{ig} \sim \mathcal{N}\left(\beta_0 + \tau_g^2 \, (\Phi \boldsymbol{\alpha}_g)_i, \; \sigma_e^2\right), \quad \boldsymbol{\alpha}_g \sim \mathcal{N}(\boldsymbol{\mu}_\alpha, \boldsymbol{\sigma}_\alpha), \quad \tau_g \sim \mathrm{LogNormal}(\mu_\tau, \sigma_\tau)
Example code
To make things easier I’ve included an example spatial transcriptomics dataset from 10X Genomics in the bayesVG package, which can itself be installed from GitHub using remotes::install_github("jr-leary7/bayesVG"). I have the most recent versions of CmdStan (v2.36.0) and cmdstanr (v0.8.1) installed, and I compiled CmdStan using cpp_options = list("CXXFLAGS += -O3 -march=native -mtune=native").
Packages
library(dplyr)
library(Seurat)
library(bayesVG)
library(cmdstanr)
Data
After loading the mouse brain dataset into memory, I filter out low-quality genes and spots like so:
data("seu_brain")
gene_set_1 <- Matrix::rowSums(seu_brain@assays$Spatial$counts > 0) >= 10L
gene_set_2 <- Matrix::rowMeans(seu_brain@assays$Spatial$counts) >= 0.1
genes_keep <- rownames(seu_brain)[gene_set_1 & gene_set_2]
seu_brain <- subset(seu_brain, features = genes_keep)
spot_set_1 <- seu_brain$nCount_Spatial >= 500L
spot_set_2 <- seu_brain$nFeature_Spatial >= 1000L
spots_keep <- colnames(seu_brain)[spot_set_1 & spot_set_2]
seu_brain <- subset(seu_brain, cells = spots_keep)
seu_brain <- PercentageFeatureSet(seu_brain,
pattern = "^mt-",
col.name = "percent_mito")
seu_brain <- subset(seu_brain, subset = percent_mito < 20)
The next step is to normalize and scale the data, and identify a set of 3000 naive highly variable genes (HVGs), which we’ll use as a candidate set of SVGs.
seu_brain <- SCTransform(seu_brain,
assay = "Spatial",
variable.features.n = 3000L,
vst.flavor = "v2",
return.only.var.genes = FALSE,
seed.use = 312,
verbose = FALSE)
I prepare the long-format expression data.frame and the spatial location matrix like so:
spatial_df <- GetTissueCoordinates(seu_brain) %>%
relocate(cell) %>%
rename(spot = cell)
spatial_mtx <- scale(as.matrix(select(spatial_df, -spot)))
expr_mtx <- GetAssayData(seu_brain,
assay = "SCT",
layer = "data")
expr_df <- as.data.frame(expr_mtx[VariableFeatures(seu_brain), ]) %>%
mutate(gene = rownames(.), .before = 1) %>%
tidyr::pivot_longer(cols = !gene,
names_to = "spot",
values_to = "expression") %>%
mutate(gene = factor(gene, levels = unique(gene)),
spot = factor(spot, levels = unique(spot)),
gene_expression = as.numeric(scale(gene_expression))) %>%
as.data.frame()
The next step is to create the basis functions using k-means plus the Matérn 3/2 kernel:
M <- nrow(spatial_mtx)
k <- 20
kmeans_centers <- kmeans(spatial_mtx, centers = k, iter.max = 20L)$centers
dists_centers <- as.matrix(dist(kmeans_centers))
lscale <- median(dists_centers[upper.tri(dists_centers)])
phi <- matrix(0, nrow = M, ncol = k)
for (i in seq(k)) {
d2 <- rowSums((spatial_mtx - matrix(kmeans_centers[i, ], nrow = M, ncol = 2, byrow = TRUE))^2)
phi[, i] <- bayesVG:::maternKernel(d2,
length.scale = lscale,
nu = 1.5)
}
phi <- scale(phi)
# drop the centering/scaling attributes added by scale() so phi is a plain matrix
attr(phi, "scaled:center") <- NULL
attr(phi, "scaled:scale") <- NULL
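For reference, with d the Euclidean distance between a spot and a k-means center (so d = \sqrt{d^2} for the squared distances computed in the loop above) and \ell the length scale, the Matérn 3/2 kernel evaluated by maternKernel() with nu = 1.5 has the standard unit-variance form:

k(d) = \left(1 + \frac{\sqrt{3} d}{\ell}\right) \exp\left(-\frac{\sqrt{3} d}{\ell}\right)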
Finally, we can put all the data into a list to be passed to CmdStan:
data_list <- list(M = M,
N = nrow(expr_df),
G = length(unique(expr_df$gene)),
k = k,
spot_id = as.integer(expr_df$spot),
gene_id = as.integer(expr_df$gene),
phi = phi,
y = expr_df$gene_expression)
Modeling
The file approxGP.stan is included in bayesVG and contains the following code:
data {
int<lower=1> M; // number of spots
int<lower=1> N; // number of gene-spot pairs in long dataframe
int<lower=1> G; // number of genes
int<lower=1> k; // number of basis functions used to approximate GP
array[N] int<lower=1, upper=M> spot_id; // unique ID for each spot
array[N] int<lower=1, upper=G> gene_id; // unique ID for each gene
matrix[M, k] phi; // matrix of basis functions used to approximate GP
vector[N] y; // vector of normalized, scaled gene expression used as response variable
}
parameters {
real beta0; // global intercept
matrix[k, G] alpha_t; // transposed matrix of gene-specific coefficients for each basis function
real<lower=0> sigma_y; // observation noise of response variable
vector<lower=0>[G] amplitude; // vector of gene-specific amplitudes of the approximate GP
real mu_amplitude; // mean for the amplitude
real<lower=0> sigma_amplitude; // SD for the amplitude
vector[k] mu_alpha; // vector of means for the basis function coefficients
vector<lower=0>[k] sigma_alpha; // vector of SDs for the basis function coefficients
}
model {
beta0 ~ std_normal();
mu_alpha ~ normal(0, 2);
sigma_alpha ~ normal(0, 2);
sigma_y ~ normal(0, 2);
mu_amplitude ~ normal(0, 2);
sigma_amplitude ~ std_normal();
vector[G] amplitude_sq = square(amplitude);
matrix[M, G] phi_alpha = phi * alpha_t;
vector[N] w;
for (i in 1:N) {
w[i] = phi_alpha[spot_id[i], gene_id[i]];
}
for (i in 1:G) {
alpha_t[, i] ~ normal(mu_alpha, sigma_alpha);
amplitude[i] ~ lognormal(mu_amplitude, sigma_amplitude);
}
y ~ normal(beta0 + amplitude_sq[gene_id] .* w, sigma_y);
}
Now we can compile the model, after which we run an MLE optimization step whose result is used to initialize the full VI fit.
stan_file <- system.file("approxGP.stan", package = "bayesVG")
mod <- cmdstan_model(stan_file, compile = FALSE)
mod$compile(pedantic = TRUE,
stanc_options = list("O1"),
force_recompile = TRUE,
threads = 4L)
fit_optim <- mod$optimize(data_list,
seed = 312,
init = 0,
jacobian = FALSE,
iter = 3000L)
The VI fitting step is then performed:
fit_vi <- mod$variational(data_list,
seed = 312,
init = fit_optim,
algorithm = "meanfield",
iter = 4000L,
draws = 1000L)
The final (I promise) step is to extract posterior draws for the estimated gene-specific amplitude parameter \hat{\tau}_g, which is what we use to rank genes by their spatial variability:
gene_mapping <- data.frame(gene = as.character(expr_df$gene),
gene_id = as.character(as.integer(expr_df$gene))) %>%
distinct()
amplitude_summary <- fit_vi$summary(variables = "amplitude") %>%
rename_with(~paste0("amplitude_", .), .cols = -1) %>%
rename(amplitude_ci_ll = amplitude_q5,
amplitude_ci_ul = amplitude_q95) %>%
mutate(gene_id = sub("^.*\\[(.*)\\].*$", "\\1", variable), .before = 1) %>%
inner_join(gene_mapping, by = "gene_id") %>%
relocate(gene) %>%
select(-c(variable, gene_id)) %>%
mutate(amplitude_dispersion = amplitude_sd^2 / amplitude_mean) %>%
arrange(desc(amplitude_mean)) %>%
mutate(amplitude_mean_rank = row_number()) %>%
as.data.frame() %>%
magrittr::set_rownames(.$gene)
Questions
All of this works pretty well - the estimated SVG set corresponds accurately to the underlying biology, and the top SVGs exhibit clear spatial dependency when visualized. However, the fitting process takes a fair amount of time even on a smallish test dataset of, say, 3000 genes across 2500 spots - usually around 45 minutes for the MLE optimization and VI steps combined. I’d like to bring this number down by parallelizing with reduce_sum() and a partial_sum() function if possible, as runtime is a big concern for users with much larger spatial datasets.
I know the first step is to write a function in the functions block that computes the partial log-likelihood given a slice of the data, but I’m unsure how the slicing should be performed when I also need to index by spot_id and gene_id. I’m also unsure how to correctly pass all the parameters and data to reduce_sum(); from the docs it appears that you can pass as many trailing arguments as you need, with data arguments marked by the data qualifier in the partial-sum function’s signature, but in practice this has thrown errors related to the number of arguments. Finally, I’m confused by the int dummy argument I’ve seen passed to reduce_sum() in several of the examples I’ve read through online (one example) - what does it do, and why is it necessary? I can attach the full code I’ve tried if needed, but it doesn’t currently compile so I’m not sure it’s useful, and this post is probably already too long; a stripped-down sketch of the direction I’ve been attempting is below.
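For concreteness, here’s the stripped-down sketch, based on my possibly-wrong reading of the examples: slice over a dummy index sequence 1:N so that y, spot_id, and gene_id can be indexed by start:end inside the partial-sum function, with everything else passed as trailing arguments. The names (partial_ll, seq) are my own placeholders, and whether this is the idiomatic approach is exactly what I’m unsure about:

functions {
  // partial log-likelihood over observations start:end of the long dataframe;
  // seq_slice is a dummy slice of 1:N that reduce_sum() only uses to set start/end
  real partial_ll(array[] int seq_slice, int start, int end,
                  vector y, array[] int spot_id, array[] int gene_id,
                  matrix phi_alpha, vector amplitude_sq,
                  real beta0, real sigma_y) {
    int n = end - start + 1;
    vector[n] mu;
    for (i in 1:n) {
      int j = start + i - 1; // index into the full data
      mu[i] = beta0 + amplitude_sq[gene_id[j]] * phi_alpha[spot_id[j], gene_id[j]];
    }
    return normal_lpdf(y[start:end] | mu, sigma_y);
  }
}
// (data block unchanged from approxGP.stan; grainsize could instead be passed as data)
transformed data {
  int grainsize = 1; // 1 lets the scheduler choose the chunk sizes
  array[N] int seq; // dummy sequence for reduce_sum() to slice
  for (n in 1:N) seq[n] = n;
}

with the vectorized likelihood line in the model block replaced by:

target += reduce_sum(partial_ll, seq, grainsize,
                     y, spot_id, gene_id, phi_alpha, amplitude_sq,
                     beta0, sigma_y);

If this is roughly right, I’m guessing the int dummy in those examples plays the same role as seq_slice here: reduce_sum() needs a first argument to slice, so when the real data can’t be sliced directly you hand it a throwaway array and do the indexing yourself via start and end. I also assume I’d need to recompile with threading support (stan_threads) and set the number of threads when calling $optimize() and $variational(), but corrections on any of this are welcome.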
Any guidance or ideas are greatly appreciated!
-jack