Hi all,
I’m currently working on what I think is a new-ish type of Gaussian process model, based on perusing these forums, the Stan docs, the brms and cmdstanr docs, etc. The model combines the Hilbert space approximation with a hierarchical structure via hyperpriors. What I’m hoping to get is a basic sanity check on my approach, as well as some guidance on how to better optimize my Stan code, as I’m relatively new to the language. As of now, the model compiles correctly and the results mostly make sense, so I’m not dealing with any explicit errors or warnings.
The project I’m working on concerns spatial transcriptomics data, i.e., an $n_s \times n_g$ spot-by-gene matrix of normalized and scaled (thus approximately Gaussian-distributed) gene expression. This is then converted to a long-format data.frame with columns gene, spot, and expression, having $n_s \times n_g$ rows. Each spot also has an associated $(x, y)$ coordinate specifying its location in Euclidean space.
As an example, consider the brain dataset from 10X Genomics made available in the SeuratData package. First I load the data, then filter out low-quality genes, i.e., those expressed in <10 spots or those with a mean expression of <0.1.
# load packages
library(dplyr)
library(Seurat)
library(cmdstanr)
# SeuratData::InstallData("stxBrain") -- run once if not run previously
seu_brain <- SeuratData::LoadData("stxBrain", type = "anterior1")
gene_set_1 <- Matrix::rowSums(seu_brain@assays$Spatial$counts > 0) >= 10L
gene_set_2 <- Matrix::rowMeans(seu_brain@assays$Spatial$counts) >= 0.1
genes_keep <- rownames(seu_brain)[gene_set_1 & gene_set_2]
seu_brain <- subset(seu_brain, features = genes_keep)
Next, the raw, integer-valued expression data are normalized using the SCTransform tool:
seu_brain <- SCTransform(seu_brain,
                         assay = "Spatial",
                         variable.features.n = 3000L,
                         vst.flavor = "v2",
                         return.only.var.genes = FALSE,
                         seed.use = 312,
                         verbose = FALSE)
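I then estimate a global length-scale $\ell$ and build a matrix of basis functions $\boldsymbol{\phi}$ to approximate the GP, assuming an exponentiated quadratic covariance kernel. The length-scale is taken as the median pairwise distance between k-means centers of the scaled coordinates, and each basis function is the kernel evaluated between every spot and one of those centers: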
# extract spatial coordinates
spatial_df <- GetTissueCoordinates(seu_brain) %>%
  relocate(cell) %>%
  rename(spot = cell)
spatial_mtx <- scale(as.matrix(select(spatial_df, -spot)))
# convert gene expression matrix to long-format
expr_mtx <- GetAssayData(seu_brain,
                         assay = "SCT",
                         layer = "scale.data")
expr_df <- as.data.frame(expr_mtx) %>%
  mutate(gene = rownames(.), .before = 1) %>%
  tidyr::pivot_longer(cols = !gene,
                      names_to = "spot",
                      values_to = "expression") %>%
  mutate(gene = factor(gene, levels = unique(gene)),
         spot = factor(spot, levels = unique(spot)))
# estimate global length-scale and matrix of basis functions
M <- nrow(spatial_mtx)
k <- 20
kmeans_centers <- kmeans(spatial_mtx, centers = k, iter.max = 20L)$centers
dists_centers <- as.matrix(dist(kmeans_centers))
lscale <- median(dists_centers[upper.tri(dists_centers)])
phi <- matrix(0, nrow = M, ncol = k)
for (i in seq(k)) {
  # squared Euclidean distance from every spot to the i-th k-means center
  d2 <- rowSums((spatial_mtx - matrix(kmeans_centers[i, ], nrow = M, ncol = 2, byrow = TRUE))^2)
  phi[, i] <- exp(-d2 / (2 * lscale^2))
}
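For reference, writing $\mathbf{s}_i$ for the scaled coordinates of spot $i$ and $\mathbf{c}_j$ for the $j$-th k-means center (symbols introduced here for illustration only), the loop above computes

$$
\phi_{ij} = \exp\left(-\frac{\lVert \mathbf{s}_i - \mathbf{c}_j \rVert^2}{2\ell^2}\right), \qquad \ell = \operatorname{median}_{j < j'} \lVert \mathbf{c}_j - \mathbf{c}_{j'} \rVert .
$$

Below is a minimal sketch on simulated coordinates (`toy_coords`, `k_toy`, and the other `*_toy` names are hypothetical stand-ins, not the real data) showing that the same matrix can be built without the loop, and that both versions agree:

```r
set.seed(312)
# simulated stand-in for spatial_mtx: 200 spots with scaled (x, y) coordinates
toy_coords <- scale(matrix(runif(400), ncol = 2))
k_toy <- 5
centers_toy <- kmeans(toy_coords, centers = k_toy, iter.max = 20L)$centers
d_centers <- as.matrix(dist(centers_toy))
lscale_toy <- median(d_centers[upper.tri(d_centers)])

# loop version, mirroring the code in the post
phi_loop <- matrix(0, nrow = nrow(toy_coords), ncol = k_toy)
for (i in seq(k_toy)) {
  d2 <- rowSums((toy_coords - matrix(centers_toy[i, ], nrow = nrow(toy_coords), ncol = 2, byrow = TRUE))^2)
  phi_loop[, i] <- exp(-d2 / (2 * lscale_toy^2))
}

# vectorized version: ||s - c||^2 = ||s||^2 + ||c||^2 - 2 * s.c
d2_all <- outer(rowSums(toy_coords^2), rowSums(centers_toy^2), "+") -
  2 * tcrossprod(toy_coords, centers_toy)
# pmax() guards against tiny negative values caused by floating-point error
phi_vec <- exp(-pmax(d2_all, 0) / (2 * lscale_toy^2))

# both constructions should agree up to numerical tolerance
stopifnot(isTRUE(all.equal(phi_loop, phi_vec, check.attributes = FALSE)))
```

With only k = 20 centers the loop is cheap either way, so the vectorized form is mostly a convenience if k grows.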
The Stan model is stored in the file gp-test.stan and contains the following code:
data {
  int<lower=1> M;  // number of spots
  int<lower=1> N;  // number of gene-spot pairs in long dataframe
  int<lower=1> G;  // number of genes
  int<lower=1> k;  // number of basis functions used to approximate GP
  array[N] int<lower=1, upper=M> spot_id;  // unique ID for each spot
  array[N] int<lower=1, upper=G> gene_id;  // unique ID for each gene
  matrix[M, k] phi;  // matrix of basis functions used to approximate GP
  vector[N] y;  // vector of normalized gene expression used as response variable
}
parameters {
  vector[G] beta0;  // vector of gene-specific intercepts
  matrix[G, k] alpha;  // matrix of gene-specific coefficients for each basis function
  real<lower=0> sigma_y;  // observational noise (SD) of response variable (normalized gene expression)
  vector<lower=0>[G] amplitude;  // vector of gene-specific marginal standard deviations of the approximate GP
  real mu_amplitude;  // log-scale mean for the marginal standard deviation
  real<lower=0> sigma_amplitude;  // log-scale standard deviation for the marginal standard deviation
  real mu_beta0;  // mean for the gene-specific intercepts
  real<lower=0> sigma_beta0;  // standard deviation for the gene-specific intercepts
  vector[k] mu_alpha;  // vector of means for the basis function coefficients
  vector<lower=0>[k] sigma_alpha;  // standard deviations for the basis function coefficients
}
model {
  // hyperpriors (normal() takes a standard deviation as its second argument)
  mu_beta0 ~ normal(0, 5);
  sigma_beta0 ~ normal(0, 2);
  mu_alpha ~ normal(0, 2);
  sigma_alpha ~ normal(0, 2);
  amplitude ~ lognormal(0, 1);  // note: amplitude also receives the hierarchical prior in the loop below
  sigma_y ~ normal(0, 2);
  mu_amplitude ~ normal(0, 1);
  sigma_amplitude ~ normal(0, 1);
  // gene-level priors
  for (i in 1:G) {
    beta0[i] ~ normal(mu_beta0, sigma_beta0);
    for (j in 1:k) {
      alpha[i, j] ~ normal(mu_alpha[j], sigma_alpha[j]);
    }
    amplitude[i] ~ lognormal(mu_amplitude, sigma_amplitude);
  }
  // likelihood over all gene-spot pairs
  for (i in 1:N) {
    int g = gene_id[i];
    int p = spot_id[i];
    real w_i = dot_product(phi[p], alpha[g]);
    y[i] ~ normal(beta0[g] + amplitude[g]^2 * w_i, sigma_y);
  }
}
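To make the structure easier to discuss, here is the model as I read it off the Stan code, written in math notation. The index notation $g[i]$ and $p[i]$ (the gene and spot of observation $i$) and the symbol $a_g$ for `amplitude` are shorthand introduced here, and every second argument is a standard deviation, following Stan's parameterization:

$$
\begin{aligned}
y_i &\sim \mathcal{N}\left(\beta_{0,g[i]} + a_{g[i]}^2 \, \boldsymbol{\phi}_{p[i]}^\top \boldsymbol{\alpha}_{g[i]},\ \sigma_y\right), \\
\beta_{0,g} &\sim \mathcal{N}(\mu_{\beta_0}, \sigma_{\beta_0}), \\
\alpha_{g,j} &\sim \mathcal{N}(\mu_{\alpha,j}, \sigma_{\alpha,j}), \\
a_g &\sim \operatorname{LogNormal}(\mu_a, \sigma_a), \\
\mu_{\beta_0} &\sim \mathcal{N}(0, 5), \quad \mu_{\alpha,j} \sim \mathcal{N}(0, 2), \quad \mu_a \sim \mathcal{N}(0, 1), \\
\sigma_{\beta_0},\ \sigma_{\alpha,j},\ \sigma_y &\sim \mathcal{N}^{+}(0, 2), \quad \sigma_a \sim \mathcal{N}^{+}(0, 1).
\end{aligned}
$$

Two things stand out when it is written this way: the amplitude enters the mean squared, and as coded `amplitude` has a second prior statement, $a_g \sim \operatorname{LogNormal}(0, 1)$, so both lognormal densities contribute to the log posterior.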
Lastly, I pass the data to Stan and fit the model using cmdstanr with the mean-field variational inference (VI) algorithm as shown below. As a note, I’m using the cmdstanr package v0.8.0 and CmdStan v2.36.0.
data_list <- list(M = M,
                  N = nrow(expr_df),
                  G = length(unique(expr_df$gene)),
                  k = k,
                  # explicit integer codes for the factor columns (1-based, in level order)
                  spot_id = as.integer(expr_df$spot),
                  gene_id = as.integer(expr_df$gene),
                  phi = phi,
                  y = expr_df$expression)
mod <- cmdstan_model("../gp-test.stan",
                     stanc_options = list("O1"),
                     threads = 2L)
fit_vi <- mod$variational(data_list,
                          seed = 312,
                          algorithm = "meanfield",
                          iter = 3000L,
                          draws = 1000L)
draws_vi <- fit_vi$draws()
draws_vi_summary <- fit_vi$summary()
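One assumption baked into the data list is that the integer spot IDs index the rows of phi in the same order as the coordinate matrix. Below is a minimal sketch of that check on toy labels (`toy_coords_order` and `toy_long` are hypothetical stand-ins); with the real objects, the analogous comparison would be between `levels(expr_df$spot)` and `spatial_df$spot`.

```r
# toy stand-ins: three spots and two genes in long format (made-up labels only)
toy_coords_order <- c("spot_a", "spot_b", "spot_c")  # row order of the coordinate / phi matrix
toy_long <- expand.grid(spot = toy_coords_order,
                        gene = c("g1", "g2"),
                        stringsAsFactors = FALSE)
toy_long$spot <- factor(toy_long$spot, levels = unique(toy_long$spot))
toy_long$gene <- factor(toy_long$gene, levels = unique(toy_long$gene))

# the factor levels must match the row order of phi for spot_id to index it correctly
stopifnot(identical(levels(toy_long$spot), toy_coords_order))

# explicit integer IDs, matching the `array[N] int` declarations in the Stan data block
spot_id <- as.integer(toy_long$spot)
gene_id <- as.integer(toy_long$gene)
```

If the two orderings ever differed, the model would still run but each observation would be paired with the wrong row of phi, so it seems worth asserting before fitting.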
Does the above approach make sense as a means of estimating a hierarchical approximate GP? If not, in what way have I misspecified the model, assumptions, or data? Are the basic priors I’ve chosen serviceable, or should they be modified? Does using VI instead of MCMC sampling via the NUTS algorithm (mostly for speed’s sake) make sense in this case, or is VI known to behave poorly for this type of application? Any and all guidance is greatly appreciated!
Thanks much,
Jack