Issues sampling from multivariate normal prior

Hello,

I am trying to fit a Poisson regression in Stan, through brms in R. My prior on the coefficients is a multivariate normal distribution with a non-zero mean and non-zero correlations between the different variables.

I have two models, one with ~350 features, the other with ~700. I noticed the larger one was taking significantly longer to sample from (about 30x longer, even with the same number of data points). I would expect the larger model to take longer (maybe 4x or 8x), but not 30x. So I tried to debug by setting the sample_prior = "only" argument in brms, which, as the name suggests, ignores the data and samples from the prior only. For the first model, sampling from the prior took about 52 seconds for one chain (500 warmup iterations, 1750 total). For the second model, it takes well over 30 minutes [EDIT: the first chain clocked in at 44 minutes] to produce the same number of samples.

Here is the R code:

library(brms)
library(mvtnorm) # for initFunctions

load("priors.Rda")

prior1 <- set_prior("multi_normal(mean1, vcov1)")
stanvars1 <- stanvar(mean1) + stanvar(vcov1)

initFunction1 <- function() {
  list(b = rmvnorm(1, mean = mean1, sigma = vcov1)[1, ])
}

# Fake data; shouldn't be used anyway with sample_prior = "only"
X <- data.frame(t(matrix(c(rep(0, 100), 1, 1, rep(0, 261), 75), ncol = 1)))
colnames(X)[364] <- 'y'

model1 <- brm(
  y ~ . + 0,
  family = brmsfamily("poisson", link = "identity"),
  data = X,
  prior = prior1,
  stanvars = stanvars1,
  init = initFunction1,
  chains = 4,
  iter = 1750, warmup = 500, cores = 4, refresh = 50,
  sample_prior = "only"
)


prior2 <- set_prior("multi_normal(mean2, vcov2)")
stanvars2 <- stanvar(mean2) + stanvar(vcov2)

initFunction2 <- function() {
  list(b = rmvnorm(1, mean = mean2, sigma = vcov2)[1, ])
}

# Fake data; shouldn't be used anyway with sample_prior = "only"
X <- data.frame(t(matrix(c(rep(0, 100), 50, rep(0, 300), 50, rep(0, 324), 100), ncol = 1)))
colnames(X)[727] <- 'y'

model2 <- brm(
  y ~ . + 0,
  family = brmsfamily("poisson", link = "identity"),
  data = X,
  prior = prior2,
  stanvars = stanvars2,
  init = initFunction2,
  chains = 4,
  iter = 1750, warmup = 500, cores = 4, refresh = 50,
  sample_prior = "only"
)

priors.Rda is an Rda file containing mean1, vcov1, mean2, and vcov2 (written before I realized I couldn't upload an Rda file here); the individual objects are attached to this post as CSVs instead. The sample X is fake data, but it does reflect the sparsity of my actual data (not that that matters when just sampling from the prior).

You’ll notice the scales of the two datasets are different. I tried to increase the scale of the second dataset to match that of the first, and it didn’t make any difference. I also tried to scale the first column in the second dataset (to make it more in line with the rest of the columns), and that also made no difference.

What gives? My model is just a simple Poisson GLM, so I doubt it is misspecified; besides, I am only sampling from the prior, so the model specification shouldn't matter at all.

mean1.csv (8.0 KB)
mean2.csv (16.4 KB)
vcov1.csv (2.2 MB)
vcov2.csv (7.5 MB)


OK, so I have had some opportunity to debug further.

  1. If I take the first half of the columns of the second dataset, it samples just as fast as the first dataset (52 seconds).
  2. If I take the first dataset and repeat it (doubling the number of columns, thus making it as big as the second dataset), putting random values in the covariance matrix (construction sketched below), it takes a reasonable amount of time to sample (about 6 minutes).
  3. But somehow, the second dataset at its full size takes 45 minutes. That doesn't make any sense to me.
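For reference, the doubled prior in point (2) was built roughly like this (a sketch; the noise scale is arbitrary, it just needs to keep the matrix positive definite while making the extra correlations non-zero):

K1 <- length(mean1)
mean_dbl <- c(mean1, mean1)

# two block-diagonal copies of vcov1 plus a small random positive semi-definite perturbation
noise <- matrix(rnorm((2 * K1)^2, sd = 0.01), nrow = 2 * K1)
vcov_dbl <- kronecker(diag(2), vcov1) + crossprod(noise)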

Does it speed up the computations to use the multi_normal_cholesky parameterization instead of multi_normal?

Precompute the Cholesky decomposition (R's chol() returns the upper-triangular factor, while Stan's multi_normal_cholesky expects the lower-triangular one, hence the transpose):

chol2 <- t(chol(vcov2))

Update the prior specification:

prior2 <- set_prior("multi_normal_cholesky(mean2, chol2)")
stanvars2 <- stanvar(mean2) + stanvar(chol2)

Thanks for the suggestion; unfortunately, it didn't make any difference (52 seconds for the small model, 45 minutes for the big one).

Are there any tuning parameters for the sampling algorithm I could try? I suspect it's not an issue with the actual mathematical operations (considering point (2) in the previous comment), but that the shape of the distribution is somehow giving the sampler problems.

I see now that you are correct and the issue is indeed with the structure of vcov2. Plotting the actual numbers, vcov2 has three components: the first column corresponds to an intercept (it is uncorrelated with the rest), and the remaining columns form two blocks of about the same size, say group1 and group2.
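For reference, the plot I mean is just a heatmap of the implied correlation matrix, something like:

# quick look at the correlation structure of vcov2
image(cov2cor(vcov2))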

All the correlations are positive, and the correlations between group1 and group2 are higher than the correlations within each group. Perhaps that's the source of the issue? In that case it might be worth revisiting the choice of prior. Where does it come from?

Each row in the dataset is essentially the average of two ratings, one from each group (plus a possible intercept). Really, the dataset has one more column, but since including it would make the X matrix not full rank, it is dropped. I want to enforce the constraint that (sum of ratings in group1) = (sum of ratings in group2, including the dropped coefficient), so I have to introduce the positive correlations between the variables. Otherwise, if I treat the priors as independent univariate normals, then with little data the sampled values for the "reverse-engineered" dropped coefficient might not make any sense. This started with a question I asked on Cross-Validated SE here.
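In symbols, writing $b_d$ for the dropped coefficient and group2 for the retained coefficients of the second group, the constraint is

$$\sum_{j \in \text{group1}} b_j = b_d + \sum_{j \in \text{group2}} b_j,$$

so the reverse-engineered value is $b_d = \sum_{\text{group1}} b_j - \sum_{\text{group2}} b_j$, and with independent priors and little data that quantity can wander off to nonsensical values.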

As for the source of the priors themselves, they are empirically derived, essentially by running the same regression on historical data (frequentist/MLE approach), and then adding a lot of variance to widen the resulting distributions.


I don't fully understand the model, so I'll stick to the problem of how to sample from the multivariate normal prior. What about doing it in two steps: (1) sample b2; (2) sample b1 given b2, where b1 and b2 correspond to the two groups, so that the full parameter vector is b = [b1, b2]? That is, factorize the prior on b into two components of about the same dimension, since the groups are about the same size. The formula for the conditional MVN is below.
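With the prior partitioned as $\mu = (\mu_1, \mu_2)$ and $\Sigma = \begin{pmatrix} \Sigma_{11} & \Sigma_{12} \\ \Sigma_{21} & \Sigma_{22} \end{pmatrix}$, the standard conditioning identity is

$$b_1 \mid b_2 \sim \mathcal{N}\!\left(\mu_1 + \Sigma_{12}\Sigma_{22}^{-1}(b_2 - \mu_2),\; \Sigma_{11} - \Sigma_{12}\Sigma_{22}^{-1}\Sigma_{21}\right),$$

which is what the transformed data and transformed parameters blocks below compute.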

Here is my attempt (in Stan). It seems to run faster, but you should check that the math is right.

data {
  int<lower=1> K1;
  int<lower=1> K2;
  vector[K1] mu1;
  vector[K2] mu2;
  matrix[K1, K1] Sigma11;
  matrix[K1, K2] Sigma12;
  matrix[K2, K1] Sigma21;
  matrix[K2, K2] Sigma22;
}
transformed data {
  matrix[K2, K2] L22 = cholesky_decompose(Sigma22);

  // precompute Sigma12 * inverse(Sigma22) and the conditional covariance of b1 given b2
  matrix[K1, K2] Sigma12_div_Sigma22 = Sigma12 * chol2inv(L22);
  matrix[K1, K1] Sigma11_given_b2 = Sigma11 - Sigma12_div_Sigma22 * Sigma21;
  matrix[K1, K1] L11_given_b2 = cholesky_decompose(Sigma11_given_b2);
}
parameters {
  vector[K1] b1;
  vector[K2] b2;
}
transformed parameters {
  vector[K1] mu1_given_b2 = mu1 + Sigma12_div_Sigma22 * (b2 - mu2);
}
model {
  // joint prior on b = [b1, b2], factorized as p(b2) * p(b1 | b2)
  b2 ~ multi_normal_cholesky(mu2, L22);
  b1 ~ multi_normal_cholesky(mu1_given_b2, L11_given_b2);
}

Thanks! I'll need to figure out how to run this "custom" code through brms (or how to run Stan code directly, one or the other), but this should be enough to get me started.
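Based on a quick look at the cmdstanr docs, running it directly should look roughly like the sketch below (untested; mvn_prior.stan is the file holding the model above, and idx1/idx2 are placeholder index vectors for splitting the coefficients into the two blocks; the uncorrelated intercept column can simply be folded into one of them):

library(cmdstanr)

# partition the prior according to the two blocks (idx1/idx2 are hypothetical index vectors)
stan_data <- list(
  K1 = length(idx1), K2 = length(idx2),
  mu1 = mean2[idx1], mu2 = mean2[idx2],
  Sigma11 = vcov2[idx1, idx1], Sigma12 = vcov2[idx1, idx2],
  Sigma21 = vcov2[idx2, idx1], Sigma22 = vcov2[idx2, idx2]
)

mod <- cmdstan_model("mvn_prior.stan")
fit <- mod$sample(data = stan_data, chains = 4, parallel_chains = 4,
                  iter_warmup = 500, iter_sampling = 1250, refresh = 50)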

I discovered that on my machine it's faster to run the chains sequentially rather than in parallel. (This may be true for brms as well; I didn't check.) More importantly, the generated sample doesn't have the expected covariance matrix; the means are reproduced OK. It worked reasonably well for smaller examples, so either the size or the special structure of vcov2 makes this extra hard.

If the sampling works better when sampling b directly in one step, then maybe accept the inefficiency? Have you checked that with sample_prior="only" you get samples from the specified prior?

Have you checked that with sample_prior="only" you get samples from the specified prior?

Yes, I fed the function both the entire dataset (10000 rows) and a very small sample (250 rows), and the runtime was no different.

I’m playing around with within-chain parallelization in cmdstanr, and that seems to be helping a little bit.

Are there any NUTS parameters I can change that might lead to faster sampling? Perhaps the default settings aren't working well in this particular instance. I am somewhat familiar with MCMC in general but know very little about NUTS specifically, so any pointers here would be appreciated.
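For reference, these seem to be the sampler knobs cmdstanr's $sample() exposes that look most relevant (just a sketch of the defaults, not settings I've actually tuned yet):

fit <- mod$sample(
  data = stan_data,
  adapt_delta = 0.8,     # target acceptance rate during adaptation (default 0.8)
  max_treedepth = 10,    # cap on the NUTS tree depth (default 10)
  metric = "diag_e",     # mass matrix type: "diag_e" (default) or "dense_e"
  step_size = NULL,      # initial step size; adapted during warmup
  chains = 4, iter_warmup = 500, iter_sampling = 1250
)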

Aside from the runtime, have you checked that the mean and covariance of the samples you draw from the prior are anything close to the input mean and covariance? Your vcov2 has quite a lot of structure: the input variances are approximately equal, with perhaps some estimation error since they are actually MLEs based on historical data; the input correlations are nonnegative; and the matrix has two blocks. multi_normal doesn't know about this structure, so when I compute the sample covariance after supposedly sampling from the prior, the result isn't anything like the input covariance matrix vcov2. For example, there are a lot of negative correlations in the draws from the prior.

I can't come up with other ideas to improve the speed of the sampling, but what's the point of sampling efficiently when the samples don't look right?

I'm not so bothered about whether the sample covariance of what I draw exactly matches the actual distribution. For example, if I just draw 5000 samples directly from the distribution with x = rmvnorm(5000, mean2, vcov2), about 12% of the entries in the cov(x) result are negative. That doesn't concern me, as a) the correlations are weak to begin with; and b) I am confident that I am sampling from the desired distribution. (There is an interesting probability-textbook question there about how close sample covariances will be to those of the actual distribution. If I just call rnorm(5000), nobody expects the mean to be exactly 0.) Plus, this is just the prior; in this problem the variances are so wide that it doesn't take many data points for the posterior to become almost entirely data-driven anyway. (The only reason I am making such a big deal of this to begin with is that I would like my results to be sensible with very few data points.)
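For concreteness, the comparison I have in mind looks something like this (fit2 is a placeholder name for the prior-only brms fit):

library(mvtnorm)
library(posterior)

# draws taken straight from the prior
x_direct <- rmvnorm(5000, mean = mean2, sigma = vcov2)
mean(cov(x_direct) < 0)   # fraction of negative entries; about 12% for me

# prior-only draws produced by brms/Stan
draws <- as_draws_matrix(fit2)
b_draws <- draws[, grepl("^b_", colnames(draws))]
mean(cov(b_draws) < 0)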

Back to the original question: I've been able to speed up sampling through methods unrelated to the distribution itself, by making sacrifices regarding the number of samples and the number of cores I use. I'm now able to get half of the post-warmup samples in about a quarter of the overall time (using cmdstanr instead of rstan, within-chain parallelization, additional cores, etc.). But I would still like to understand why Stan struggles so much with this type of distribution (and ideally how to get it to not struggle), beyond "well, it's just different."

Thanks for your help!
