I am having trouble fitting a four-parameter log-logistic (LL4) model with partial pooling to dose-response data.
The four-parameter log-logistic function of the dose x is given by
f(x) = c + \dfrac{d-c}{1+\exp(b(\log(x)-\hat{e}))}
where the response is measured in counts (hence a Poisson model), and
- b is the slope parameter,
- c is the lower asymptote,
- d is the upper asymptote, and
- \hat{e} is the logarithm of the IC50 (the half maximal inhibitory dose).
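For illustration, here is a direct R translation of f(x); it is not part of the Stan model, just handy for sanity checks and for plotting fitted curves:
LL4 <- function(x, b, c, d, ehat) {
  c + (d - c) / (1 + exp(b * (log(x) - ehat)))
}
# At x = exp(ehat), i.e. at the IC50, the response is exactly halfway
# between the asymptotes: c + (d - c) / 2
LL4(exp(-19.3), b = 3, c = 50, d = 3500, ehat = -19.3)
# [1] 1775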
Question
To pre-empt the key question: Fitting the model with partial pooling leads to chains not mixing well, divergent transitions and “maximum treedepth exceeded” warnings. Is there a way to optimise the model to speed up the calculation (it’s very slow for my actual data) and to ensure that results are reliable?
Simple example: No pooling, only one group of measurements
tl;dr: Fit estimates are reasonable and the chains seem to mix well.
Model definition
model_code <- "
functions {
real LL4(real x, real b, real c, real d, real ehat) {
return c + (d - c) / (1 + exp(b * (log(x) - ehat)));
}
}
data {
int<lower=1> N; // Measurements
vector[N] x; // Dose values
int y[N]; // Response values for drug
}
// Transformed data
transformed data {
}
// Sampling space
parameters {
real<lower=0> b; // slope
real<lower=0> c; // lower asymptote
real<lower=0> d; // upper asymptote
real ehat; // loge(IC50)
}
// Transform parameters before calculating the posterior
transformed parameters {
real<lower=0> e;
e = exp(ehat);
}
// Calculate posterior
model {
vector[N] mu_y;
// Priors on parameters
c ~ normal(0, 10);
d ~ normal(max(y), 10);
ehat ~ normal(mean(log(x)), 10);
b ~ normal(1, 2);
for (i in 1:N) {
mu_y[i] = LL4(x[i], b, c, d, ehat);
}
y ~ poisson(mu_y);
}
"
I chose weakly informative normal priors; for example, the prior on \hat{e} is a normal density centered at the mean of the log-transformed doses, which assumes that measurements were chosen such that the IC50 (i.e. \exp(\hat{e})) is located somewhere between the min and max dose.
Data
data_stan <- list(N = 11L, x = c(2e-07, 1e-07, 5e-08, 2.5e-08, 1.25e-08, 6.25e-09,
3.13e-09, 1.56e-09, 7.79e-10, 3.9e-10, 1.95e-10), y = c(13L,
9L, 85L, 149L, 183L, 716L, 2600L, 3472L, 3438L, 3475L, 3343L))
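As a quick check of the \hat{e} prior above, its center for these doses is
mean(log(data_stan$x))
# approximately -18.89
so the prior comfortably covers the measured dose range on the log scale (and is close to the posterior mean of \hat{e} reported below).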
Fit
library(rstan)
fit <- stan(model_code = model_code, data = data_stan)
Fit with default parameters for stan.
Results
fit
#Inference for Stan model: 021d56ecc88759eaa829c4c0033a50f5.
#4 chains, each with iter=2000; warmup=1000; thin=1;
#post-warmup draws per chain=1000, total post-warmup draws=4000.
#
# mean se_mean sd 2.5% 25% 50% 75% 97.5%
#b 3.04 0.00 0.10 2.84 2.97 3.04 3.11 3.24
#c 48.80 0.08 3.80 41.55 46.21 48.74 51.34 56.40
#d 3477.13 0.14 9.63 3458.73 3470.60 3477.02 3483.81 3495.99
#ehat -19.30 0.00 0.01 -19.32 -19.30 -19.30 -19.29 -19.27
#e 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00
#lp__ 121443.74 0.03 1.38 121440.23 121443.05 121444.05 121444.78 121445.47
# n_eff Rhat
#b 2391 1
#c 2386 1
#d 4782 1
#ehat 2840 1
#e 2842 1
#lp__ 2103 1
#
#Samples were drawn using NUTS(diag_e) at Tue Mar 24 14:14:48 2020.
#For each parameter, n_eff is a crude measure of effective sample size,
#and Rhat is the potential scale reduction factor on split chains (at
#convergence, Rhat=1).
Parameter estimates are reasonable and the chains seem to have mixed well.
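For completeness, a couple of ways to verify this with rstan's built-in helpers:
# Check divergences, treedepth saturation, and energy (E-BFMI)
check_hmc_diagnostics(fit)
# Visual check of chain mixing for the main parameters
traceplot(fit, pars = c("b", "c", "d", "ehat"))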
Model with partial pooling: Multiple groups
tl;dr: Chains don’t seem to mix well; I get divergent transitions and treedepth warnings.
Model definition
model_code <- "
functions {
real LL4(real x, real b, real c, real d, real ehat) {
return c + (d - c) / (1 + exp(b * (log(x) - ehat)));
}
}
data {
int<lower=1> N; // Measurements
int<lower=1> J; // Number of drugs
int<lower=1,upper=J> drug[N]; // Drug of measurement
vector[N] x; // Dose values
int y[N]; // Response values for drug
}
// Transformed data
transformed data {
}
// Sampling space
parameters {
vector<lower=0>[J] b; // slope
vector<lower=0>[J] c; // lower asymptote
vector<lower=0>[J] d; // upper asymptote
vector[J] ehat; // loge(IC50)
// Hyperparameters
real<lower=0> mu_c;
real<lower=0> sigma_c;
real<lower=0> mu_d;
real<lower=0> sigma_d;
real mu_ehat;
real<lower=0> sigma_ehat;
real<lower=0> mu_b;
real<lower=0> sigma_b;
}
// Transform parameters before calculating the posterior
transformed parameters {
vector<lower=0>[J] e;
e = exp(ehat);
}
// Calculate posterior
model {
vector[N] mu_y;
// Parameters parameters
c ~ normal(mu_c, sigma_c);
d ~ normal(mu_d, sigma_d);
ehat ~ normal(mu_ehat, sigma_ehat);
b ~ normal(mu_b, sigma_b);
// Priors for hyperparameters
mu_c ~ normal(0, 10);
sigma_c ~ cauchy(0, 2.5);
mu_d ~ normal(max(y), 10);
sigma_d ~ cauchy(0, 2.5);
mu_ehat ~ normal(mean(log(x)), 10);
sigma_ehat ~ cauchy(0, 2.5);
mu_b ~ normal(1, 2);
sigma_b ~ cauchy(0, 2.5);
for (i in 1:N) {
mu_y[i] = LL4(x[i], b[drug[i]], c[drug[i]], d[drug[i]], ehat[drug[i]]);
}
y ~ poisson(mu_y);
}
"
Data
data_stan <- list(N = 33L, J = 3L, drug = c(1L, 2L, 3L, 1L, 2L, 3L, 1L, 2L,
3L, 1L, 2L, 3L, 1L, 2L, 3L, 1L, 2L, 3L, 1L, 2L, 3L, 1L, 2L, 3L,
1L, 2L, 3L, 1L, 2L, 3L, 1L, 2L, 3L), x = c(2e-07, 2e-07, 2e-07,
1e-07, 1e-07, 1e-07, 5e-08, 5e-08, 5e-08, 2.5e-08, 2.5e-08, 2.5e-08,
1.25e-08, 1.25e-08, 1.25e-08, 6.25e-09, 6.25e-09, 6.25e-09, 3.13e-09,
3.13e-09, 3.13e-09, 1.56e-09, 1.56e-09, 1.56e-09, 7.79e-10, 7.79e-10,
7.79e-10, 3.9e-10, 3.9e-10, 3.9e-10, 1.95e-10, 1.95e-10, 1.95e-10
), y = c(13L, 13L, 9L, 9L, 18L, 27L, 85L, 50L, 37L, 149L, 119L,
147L, 183L, 38L, 167L, 716L, 375L, 585L, 2600L, 763L, 997L, 3472L,
1288L, 2013L, 3438L, 1563L, 2334L, 3475L, 2092L, 2262L, 3343L,
2032L, 2575L))
Fit
library(rstan)
fit <- stan(model_code = model_code, data = data_stan)
Results
fit
Inference for Stan model: 3030bb3e5df67b1b6634384b49859e66.
4 chains, each with iter=2000; warmup=1000; thin=1;
post-warmup draws per chain=1000, total post-warmup draws=4000.
mean se_mean sd 2.5% 25% 50% 75%
b[1] 3.11 0.00 0.12 2.89 3.04 3.11 3.19
b[2] 1.50 0.00 0.06 1.38 1.46 1.50 1.54
b[3] 1.57 0.00 0.05 1.48 1.54 1.57 1.60
c[1] 56.01 0.09 4.55 47.24 53.05 55.95 59.00
c[2] 19.05 0.07 3.87 11.65 16.33 18.97 21.65
c[3] 14.86 0.07 3.78 7.86 12.20 14.75 17.40
d[1] 3490.82 0.61 30.44 3430.73 3469.63 3491.34 3511.80
d[2] 2131.86 0.99 45.78 2043.71 2101.47 2129.82 2162.24
d[3] 2579.64 0.99 39.66 2501.46 2552.40 2579.29 2606.35
ehat[1] -19.30 0.00 0.02 -19.33 -19.31 -19.30 -19.29
ehat[2] -20.04 0.00 0.05 -20.14 -20.08 -20.04 -20.01
ehat[3] -19.74 0.00 0.04 -19.80 -19.76 -19.74 -19.71
mu_c 12.40 0.15 7.36 0.55 6.46 11.86 17.67
sigma_c 28.74 0.39 19.12 11.05 18.77 24.59 33.39
mu_d 3474.56 0.12 9.95 3454.82 3467.89 3474.69 3481.47
sigma_d 985.14 10.31 455.20 487.32 689.04 869.30 1150.68
mu_ehat -19.70 0.02 0.70 -21.05 -19.94 -19.69 -19.44
sigma_ehat 0.90 0.02 0.87 0.23 0.41 0.63 1.07
mu_b 1.96 0.02 0.70 0.52 1.55 1.97 2.35
sigma_b 1.43 0.02 0.93 0.51 0.86 1.19 1.69
e[1] 0.00 0.00 0.00 0.00 0.00 0.00 0.00
e[2] 0.00 0.00 0.00 0.00 0.00 0.00 0.00
e[3] 0.00 0.00 0.00 0.00 0.00 0.00 0.00
lp__ 245804.20 0.10 3.47 245796.42 245802.16 245804.59 245806.67
97.5% n_eff Rhat
b[1] 3.36 1986 1.00
b[2] 1.62 1979 1.00
b[3] 1.67 1620 1.00
c[1] 65.15 2807 1.00
c[2] 26.67 2791 1.00
c[3] 22.47 2823 1.00
d[1] 3552.15 2460 1.00
d[2] 2224.56 2155 1.00
d[3] 2657.44 1613 1.00
ehat[1] -19.27 2174 1.00
ehat[2] -19.95 2115 1.00
ehat[3] -19.67 1592 1.00
mu_c 27.40 2501 1.00
sigma_c 72.80 2385 1.00
mu_d 3493.68 6787 1.00
sigma_d 2153.57 1950 1.00
mu_ehat -18.33 1684 1.00
sigma_ehat 3.26 1328 1.00
mu_b 3.47 2065 1.00
sigma_b 4.01 1823 1.00
e[1] 0.00 2176 1.00
e[2] 0.00 2111 1.00
e[3] 0.00 1585 1.00
lp__ 245809.96 1189 1.01
Samples were drawn using NUTS(diag_e) at Tue Mar 24 14:43:38 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
The chains don’t seem to mix well, and I get warnings about exceeding the maximum treedepth and/or divergent transitions.
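These warnings can be counted directly:
# Number of post-warmup divergent transitions across all chains
get_num_divergent(fit)
# Number of iterations that saturated the maximum treedepth
get_num_max_treedepth(fit)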
I know that parameter estimates are correlated, and pairs confirms this. This will always be the case, as a change in the slope can be “compensated” (to some degree) by changes in the other parameters.
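For example, for the first drug’s parameters:
# Pairwise posterior scatterplots; correlated parameters show up
# as elongated, tilted point clouds.
pairs(fit, pars = c("b[1]", "c[1]", "d[1]", "ehat[1]"))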
My question is: What can I do to improve the quality and speed of the model fit? The example here has three groups; my actual data has measurements for close to 1000 groups. Fitting the model with partial pooling becomes very slow in that case. I could fit individual models per group, but that would mean forfeiting the advantages that come with partially pooling parameter estimates across all groups.
Based on other posts and general Stan fitting advice I’ve read, I’ve played around with
- increasing the number of iterations,
- increasing the max_treedepth, and
- increasing the adapt_delta,
with moderate success (an example call is shown below).
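For reference, these knobs are passed to stan() via the iter and control arguments; the specific values below are just examples:
fit <- stan(
  model_code = model_code,
  data = data_stan,
  iter = 4000,             # default is 2000
  control = list(
    max_treedepth = 15,    # default is 10
    adapt_delta = 0.99     # default is 0.8
  )
)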