dietary_fat_dataset_long.txt (564 Bytes)
Hi everybody,
This post is a bit lengthy, but I wanted to provide as much context as possible.
I’m trying to port a network meta-analysis from BUGS/JAGS to Stan, the main motivations being pre-compiled models (the model will need to be run many times on different data sets) and faster sampling than JAGS offers. This has turned out to be a little tricky: in particular, I keep getting divergent transitions when using the NUTS algorithm, and what I believe is quite inefficient sampling with the HMC sampler.
I’ve seen some other attempts at this around on Discourse, but they all (in my view) stay too loyal to the original WinBUGS models, with ragged arrays and a bunch of constant zero-valued parameters, which makes the Stan code a lot more verbose and harder to read.
The univariate “trick” mentioned in the code is that, instead of using a multivariate normal distribution to model within-trial correlation in N-armed trials (N \geq 3), one can use a univariate conditional distribution of the effect of the k’th treatment in the i’th trial (with the trial-specific control arm being k = 1); the md vector represents the location parameter:
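Concretely (my transcription of eq. 14, p. 36, of the NICE TSD2 document linked in the code below — please double-check against the original), for arm k of trial i:

```latex
\delta_{ik} \mid \delta_{i2}, \ldots, \delta_{i(k-1)}
\sim \mathrm{N}\!\left(
  d_{t_{ik}} - d_{t_{i1}}
  + \frac{1}{k-1} \sum_{j=1}^{k-1} \left( \delta_{ij} - \left( d_{t_{ij}} - d_{t_{i1}} \right) \right),
  \; \frac{k}{2(k-1)} \, \sigma^2
\right)
```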
Although I’m not convinced that using this conditional univariate likelihood is advantageous, in this first attempt I wanted to implement the same model as the “industry standard” to have a point of reference. The model, comparing Poisson-distributed event rates, looks like this:
// Inspired by model in: https://discourse.mc-stan.org/t/replication-of-winbugs-nma-code-in-stan/3587/2
data {
  int<lower=0> ns;            // number of studies
  int<lower=0> nt;            // number of treatments
  int<lower=1> no;            // number of observations
  int<lower=0> s[no];         // study indicator
  int<lower=0> t[no];         // treatment indicator
  int<lower=0> b[no];         // baseline indicator
  int<lower=0> r[no];         // number of events
  int<lower=1> a[no];         // arm in study
  vector<lower=0>[no] E;      // exposure time (vector for vectorised sampling)
  int<lower=1> b_idx[ns];     // indices of control arms
  int<lower=1> t_idx[no-ns];  // indices of active treatment arms
}
transformed data {
  vector[no-ns] a_t = to_vector(a[t_idx]); // arms of active treatments only (used in sampling block)
}
parameters {
  vector[nt-1] d_param;  // mean effect of treatments (except control) against control
  vector[ns] mu;         // trial-specific baseline mean effect
  real<lower=0> sigma;   // heterogeneity parameter
  vector[no-ns] delta;   // comparison-specific treatment effects
}
transformed parameters {
  vector[nt] d;      // mean effect of treatments (incl. control) against control
  vector[no-ns] md;  // comparison-specific mean difference
  // Stan doesn't allow mixing random and static values
  d[1] = 0; // effect of control vs. control is zero by definition
  for (i in 2:nt) {
    d[i] = d_param[i-1]; // pull in the rest from the parameter of real interest
  }
  // Univariate trick (eq. 14 [p. 36] in http://nicedsu.org.uk/wp-content/uploads/2017/05/TSD2-General-meta-analysis-corrected-2Sep2016v2.pdf)
  {
    real weight = 0;
    int current_study = 0; // keep track of study to know when to reset multi-arm weight
    // FIX: Should probably be vectorised in some nifty way
    for (i in 1:(no-ns)) {
      int current_t_idx = t_idx[i];
      // Reset weight when encountering a new study
      if (current_study != s[current_t_idx]) {
        weight = 0;
        current_study = s[current_t_idx];
      }
      md[i] = d[t[current_t_idx]] - d[b[current_t_idx]];
      // SD weighting directly in sampling statement
      if (a[current_t_idx] > 2) {
        weight += delta[i-1] - d[t[current_t_idx-1]] + d[b[current_t_idx-1]]; // mimic sum over k-1
        md[i] += weight / (a[current_t_idx] - 1);
      }
    }
  }
}
model {
  // Priors
  mu ~ normal(0, 10);
  sigma ~ normal(0, 10);
  // Likelihoods
  r[b_idx] ~ poisson_log(mu[s[b_idx]] + log(E[b_idx]));
  r[t_idx] ~ poisson_log(mu[s[t_idx]] + delta + log(E[t_idx]));
  delta ~ normal(md, sigma * sqrt(0.5 * a_t ./ (a_t - 1.0)));
}
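As a sanity check on the scale term in the `delta` sampling statement (assuming I’ve read eq. 14 correctly): for two-arm trials (a = 2) it reduces to plain sigma, so the down-weighting only affects multi-arm trials:

```latex
\sigma \sqrt{\frac{a}{2(a-1)}} =
\begin{cases}
  \sigma, & a = 2, \\
  \sigma \sqrt{3}/2 \approx 0.87\,\sigma, & a = 3,
\end{cases}
```

which matches the conditional variance \(k\sigma^2 / (2(k-1))\) of the univariate trick.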
Running the following in RStudio:
for (p in c("rstan", "dplyr"))
  library(p, character.only = TRUE)

stan_data <- read.delim("dietary_fat_dataset_long.txt", sep = "\t") %>% # attached to post
  mutate(s = as.numeric(factor(study))) %>%
  group_by(s) %>%
  mutate(a = seq(n())) %>%
  c(ns = n_distinct(.$s),
    nt = n_distinct(.$t),
    no = nrow(.),
    b_idx = list(which(.$t == .$b)),
    t_idx = list(which(.$t != .$b)))
stan_data then looks like this:
$t
[1] 1 2 1 2 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2
$b
[1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
$E
[1] 1917.0 1925.0 43.6 41.3 38.0 393.5 373.9 4715.0 4823.0 715.0 751.0 885.0 895.0 87.8 91.0 1011.0 939.0 1544.0 1588.0 125.0 123.0
$r
[1] 113 111 1 5 3 24 20 248 269 31 28 65 48 3 1 28 39 177 174 2 1
$s
[1] 1 1 2 2 2 3 3 4 4 5 5 6 6 7 7 8 8 9 9 10 10
$a
[1] 1 2 1 2 3 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2
$ns
[1] 10
$nt
[1] 2
$no
[1] 21
$b_idx
[1] 1 3 6 8 10 12 14 16 18 20
$t_idx
[1] 2 4 5 7 9 11 13 15 17 19 21
When running the model (note that I set algorithm = "HMC"; sampling takes about 10 seconds on my MacBook Pro i7):
options(mc.cores = parallel::detectCores() - 1)
rstan_options(auto_write = TRUE)
fit <- stan(file = "poisson_mtc_re_model_with_univariate_trick_lean1.stan",
            data = stan_data, seed = 42, algorithm = "HMC", iter = 10000, chains = 7)
Stan throws these three warnings at me:
1: The largest R-hat is NA, indicating chains have not mixed.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#r-hat
2: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#bulk-ess
3: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#tail-ess
If I try algorithm = "NUTS", I get 1177 divergent transitions.
print(fit, probs = c(0.5, 0.025, 0.975))
yields (omitting the “header” and “footer”):
mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
d_param[1] -0.02 0.00 0.09 -0.20 -0.02 0.16 26377 1
mu[1] -2.83 0.00 0.08 -3.00 -2.83 -2.68 9707 1
mu[2] -2.70 0.00 0.37 -3.48 -2.67 -2.04 11642 1
mu[3] -2.86 0.00 0.17 -3.19 -2.85 -2.54 9944 1
mu[4] -2.93 0.00 0.06 -3.05 -2.93 -2.82 9194 1
mu[5] -3.20 0.00 0.15 -3.50 -3.20 -2.91 8177 1
mu[6] -2.71 0.00 0.11 -2.93 -2.71 -2.48 9649 1
mu[7] -3.89 0.01 0.53 -5.04 -3.86 -2.98 9331 1
mu[8] -3.43 0.00 0.15 -3.74 -3.42 -3.16 6905 1
mu[9] -2.18 0.00 0.07 -2.31 -2.18 -2.04 7781 1
mu[10] -4.57 0.01 0.64 -6.01 -4.51 -3.49 3832 1
sigma 0.15 0.00 0.12 0.02 0.11 0.44 2466 1
delta[1] -0.02 0.00 0.10 -0.22 -0.02 0.17 11073 1
delta[2] 0.04 0.00 0.21 -0.29 0.01 0.57 12659 1
delta[3] 0.02 0.00 0.20 -0.34 0.00 0.48 20373 1
delta[4] -0.04 0.00 0.15 -0.37 -0.03 0.25 14772 1
delta[5] 0.03 0.00 0.08 -0.12 0.02 0.19 8293 1
delta[6] -0.05 0.00 0.14 -0.36 -0.03 0.22 12595 1
delta[7] -0.11 0.00 0.14 -0.44 -0.09 0.11 5187 1
delta[8] -0.05 0.00 0.20 -0.53 -0.03 0.31 21202 1
delta[9] 0.09 0.00 0.16 -0.16 0.06 0.49 5468 1
delta[10] -0.03 0.00 0.08 -0.20 -0.03 0.13 10623 1
delta[11] -0.03 0.00 0.20 -0.47 -0.02 0.36 23127 1
d[1] 0.00 NaN 0.00 0.00 0.00 0.00 NaN NaN
d[2] -0.02 0.00 0.09 -0.20 -0.02 0.16 26377 1
md[1] -0.02 0.00 0.09 -0.20 -0.02 0.16 26377 1
md[2] -0.02 0.00 0.09 -0.20 -0.02 0.16 26377 1
md[3] 0.01 0.00 0.13 -0.21 0.00 0.32 16029 1
md[4] -0.02 0.00 0.09 -0.20 -0.02 0.16 26377 1
md[5] -0.02 0.00 0.09 -0.20 -0.02 0.16 26377 1
md[6] -0.02 0.00 0.09 -0.20 -0.02 0.16 26377 1
md[7] -0.02 0.00 0.09 -0.20 -0.02 0.16 26377 1
md[8] -0.02 0.00 0.09 -0.20 -0.02 0.16 26377 1
md[9] -0.02 0.00 0.09 -0.20 -0.02 0.16 26377 1
md[10] -0.02 0.00 0.09 -0.20 -0.02 0.16 26377 1
md[11] -0.02 0.00 0.09 -0.20 -0.02 0.16 26377 1
lp__ 5403.67 0.21 7.65 5389.42 5403.20 5418.95 1346 1
stan_dens and stan_trace don’t suggest problems with mixing, consistent with Rhat being 1 for all (non-fixed) parameters above.
In essence, with this model I get the same results as the technical report (http://nicedsu.org.uk/wp-content/uploads/2017/05/TSD2-General-meta-analysis-corrected-2Sep2016v2.pdf, p. 66), suggesting that the two models are equivalent. (Although, strangely, my sigma (standard deviation) parameter comes out at what seems to be their precision; that might be a typo in their table, or of course some odd mistake in my model.)
My concerns/questions are the following:
- Is it expected, or are there good reasons, that NUTS cannot “find its way” around the posterior when plain HMC seems to be doing fine? Are there any (more or less) obvious problems with the code that make it difficult for NUTS? Or is it likely simply a matter of estimating too many parameters from too few observations? Can enforcing fixed values for parameters cause this kind of behaviour?
- I wonder if the warnings regarding ESS are an artefact caused by forcing d[1] to zero.
- lp__ seems to have a very low n_eff. Is that a problem? All the other parameters seem fine, although an n_eff of a few thousand isn’t impressive when I have 35,000 posterior draws.
- Is this model simply ill-defined?
If you’ve made it this far, thanks for reading! I hope someone has some insights and could perhaps give me a pointer as to where to go from here.
Cheers,
Ben