Dose-response model with partial pooling

I am having trouble fitting a four-parameter log-logistic (LL4) model with partial pooling to dose-response data.

The four-parameter log-logistic function is a function of the dose x, and is given by

f(x) = c + \dfrac{d-c}{1+\exp(b(\log(x)-\hat{e}))}

where the response is measured in counts (hence a Poisson model), and

  • b is the slope parameter,
  • c is the lower asymptote,
  • d is the upper asymptote, and
  • \hat{e} is the logarithm of the IC50 (the half maximal inhibitory dose).
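
To see why \hat{e} is the half-maximal point, evaluate f at x = \exp(\hat{e}): the exponent b(\log(x)-\hat{e}) vanishes, so

f(\exp(\hat{e})) = c + \dfrac{d-c}{1+\exp(0)} = \dfrac{c+d}{2},

i.e. the response midway between the two asymptotes.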

Question

To pre-empt the key question: Fitting the model with partial pooling leads to chains not mixing well, divergent transitions and “maximum treedepth exceeded” warnings. Is there a way to optimise the model to speed up the calculation (it’s very slow for my actual data) and to ensure that results are reliable?

Simple example: No pooling, only one group of measurements

tl;dr: Fit estimates are reasonable and the chains seem to mix well.

Model definition

model_code <- "
functions {
  real LL4(real x, real b, real c, real d, real ehat) {
    return c + (d - c) / (1 + exp(b * (log(x) - ehat)));
  }
}

data {
  int<lower=1> N;                           // Measurements
  vector[N] x;                              // Dose values
  int y[N];                                 // Response values for drug
}

// Transformed data
transformed data {
}

// Sampling space
parameters {
  real<lower=0> b;                     // slope
  real<lower=0> c;                     // lower asymptote
  real<lower=0> d;                     // upper asymptote
  real ehat;                           // loge(IC50)
}

// Transform parameters before calculating the posterior
transformed parameters {
  real<lower=0> e;
  e = exp(ehat);
}

// Calculate posterior
model {
  vector[N] mu_y;

  // Priors on parameters
  c ~ normal(0, 10);
  d ~ normal(max(y), 10);
  ehat ~ normal(mean(log(x)), 10);
  b ~ normal(1, 2);

  for (i in 1:N) {
    mu_y[i] = LL4(x[i], b, c, d, ehat);
  }
  y ~ poisson(mu_y);

}
"

I chose weakly informative normal priors; for example, the prior on \hat{e} is a normal density centered at the mean of the log-transformed doses, which assumes that measurements were chosen such that the IC50 (i.e. \exp(\hat{e})) is located somewhere between the min and max dose.
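
For the doses below, which halve at each step, the prior centre works out to

\text{mean}(\log(x)) = \tfrac{1}{2}\left(\log(2\times10^{-7}) + \log(1.95\times10^{-10})\right) \approx -18.9,

the log of the geometric mean of the doses.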

Data

data_stan <- list(N = 11L, x = c(2e-07, 1e-07, 5e-08, 2.5e-08, 1.25e-08, 6.25e-09,
3.13e-09, 1.56e-09, 7.79e-10, 3.9e-10, 1.95e-10), y = c(13L,
9L, 85L, 149L, 183L, 716L, 2600L, 3472L, 3438L, 3475L, 3343L))

Fit

library(rstan)
fit <- stan(model_code = model_code, data = data_stan)

Fit with default parameters for stan.

Results

fit
#Inference for Stan model: 021d56ecc88759eaa829c4c0033a50f5.
#4 chains, each with iter=2000; warmup=1000; thin=1;
#post-warmup draws per chain=1000, total post-warmup draws=4000.
#
#          mean se_mean   sd      2.5%       25%       50%       75%     97.5%
#b         3.04    0.00 0.10      2.84      2.97      3.04      3.11      3.24
#c        48.80    0.08 3.80     41.55     46.21     48.74     51.34     56.40
#d      3477.13    0.14 9.63   3458.73   3470.60   3477.02   3483.81   3495.99
#ehat    -19.30    0.00 0.01    -19.32    -19.30    -19.30    -19.29    -19.27
#e         0.00    0.00 0.00      0.00      0.00      0.00      0.00      0.00
#lp__ 121443.74    0.03 1.38 121440.23 121443.05 121444.05 121444.78 121445.47
#     n_eff Rhat
#b     2391    1
#c     2386    1
#d     4782    1
#ehat  2840    1
#e     2842    1
#lp__  2103    1
#
#Samples were drawn using NUTS(diag_e) at Tue Mar 24 14:14:48 2020.
#For each parameter, n_eff is a crude measure of effective sample size,
#and Rhat is the potential scale reduction factor on split chains (at
#convergence, Rhat=1).

Parameter estimates are reasonable and the chains seem to have mixed well.

Model with partial pooling: Multiple groups

tl;dr: Chains don’t seem to mix well.

Model definition

model_code <- "
functions {
  real LL4(real x, real b, real c, real d, real ehat) {
    return c + (d - c) / (1 + exp(b * (log(x) - ehat)));
  }
}

data {
  int<lower=1> N;                           // Measurements
  int<lower=1> J;                           // Number of drugs
  int<lower=1,upper=J> drug[N];             // Drug of measurement
  vector[N] x;                              // Dose values
  int y[N];                                 // Response values for drug
}

// Transformed data
transformed data {
}

// Sampling space
parameters {
  vector<lower=0>[J] b;                     // slope
  vector<lower=0>[J] c;                     // lower asymptote
  vector<lower=0>[J] d;                     // upper asymptote
  vector[J] ehat;                           // loge(IC50)

  // Hyperparameters
  real<lower=0> mu_c;
  real<lower=0> sigma_c;
  real<lower=0> mu_d;
  real<lower=0> sigma_d;
  real mu_ehat;
  real<lower=0> sigma_ehat;
  real<lower=0> mu_b;
  real<lower=0> sigma_b;
}

// Transform parameters before calculating the posterior
transformed parameters {
  vector<lower=0>[J] e;
  e = exp(ehat);
}

// Calculate posterior
model {
  vector[N] mu_y;

  // Priors on parameters
  c ~ normal(mu_c, sigma_c);
  d ~ normal(mu_d, sigma_d);
  ehat ~ normal(mu_ehat, sigma_ehat);
  b ~ normal(mu_b, sigma_b);

  // Priors for hyperparameters
  mu_c ~ normal(0, 10);
  sigma_c ~ cauchy(0, 2.5);
  mu_d ~ normal(max(y), 10);
  sigma_d ~ cauchy(0, 2.5);
  mu_ehat ~ normal(mean(log(x)), 10);
  sigma_ehat ~ cauchy(0, 2.5);
  mu_b ~ normal(1, 2);
  sigma_b ~ cauchy(0, 2.5);


  for (i in 1:N) {
    mu_y[i] = LL4(x[i], b[drug[i]], c[drug[i]], d[drug[i]], ehat[drug[i]]);
  }
  y ~ poisson(mu_y);

}
"

Data

data_stan <- list(N = 33L, J = 3L, drug = c(1L, 2L, 3L, 1L, 2L, 3L, 1L, 2L,
3L, 1L, 2L, 3L, 1L, 2L, 3L, 1L, 2L, 3L, 1L, 2L, 3L, 1L, 2L, 3L,
1L, 2L, 3L, 1L, 2L, 3L, 1L, 2L, 3L), x = c(2e-07, 2e-07, 2e-07,
1e-07, 1e-07, 1e-07, 5e-08, 5e-08, 5e-08, 2.5e-08, 2.5e-08, 2.5e-08,
1.25e-08, 1.25e-08, 1.25e-08, 6.25e-09, 6.25e-09, 6.25e-09, 3.13e-09,
3.13e-09, 3.13e-09, 1.56e-09, 1.56e-09, 1.56e-09, 7.79e-10, 7.79e-10,
7.79e-10, 3.9e-10, 3.9e-10, 3.9e-10, 1.95e-10, 1.95e-10, 1.95e-10
), y = c(13L, 13L, 9L, 9L, 18L, 27L, 85L, 50L, 37L, 149L, 119L,
147L, 183L, 38L, 167L, 716L, 375L, 585L, 2600L, 763L, 997L, 3472L,
1288L, 2013L, 3438L, 1563L, 2334L, 3475L, 2092L, 2262L, 3343L,
2032L, 2575L))

Fit

library(rstan)
fit <- stan(model_code = model_code, data = data_stan)

Results

fit
Inference for Stan model: 3030bb3e5df67b1b6634384b49859e66.
4 chains, each with iter=2000; warmup=1000; thin=1;
post-warmup draws per chain=1000, total post-warmup draws=4000.

                mean se_mean     sd      2.5%       25%       50%       75%
b[1]            3.11    0.00   0.12      2.89      3.04      3.11      3.19
b[2]            1.50    0.00   0.06      1.38      1.46      1.50      1.54
b[3]            1.57    0.00   0.05      1.48      1.54      1.57      1.60
c[1]           56.01    0.09   4.55     47.24     53.05     55.95     59.00
c[2]           19.05    0.07   3.87     11.65     16.33     18.97     21.65
c[3]           14.86    0.07   3.78      7.86     12.20     14.75     17.40
d[1]         3490.82    0.61  30.44   3430.73   3469.63   3491.34   3511.80
d[2]         2131.86    0.99  45.78   2043.71   2101.47   2129.82   2162.24
d[3]         2579.64    0.99  39.66   2501.46   2552.40   2579.29   2606.35
ehat[1]       -19.30    0.00   0.02    -19.33    -19.31    -19.30    -19.29
ehat[2]       -20.04    0.00   0.05    -20.14    -20.08    -20.04    -20.01
ehat[3]       -19.74    0.00   0.04    -19.80    -19.76    -19.74    -19.71
mu_c           12.40    0.15   7.36      0.55      6.46     11.86     17.67
sigma_c        28.74    0.39  19.12     11.05     18.77     24.59     33.39
mu_d         3474.56    0.12   9.95   3454.82   3467.89   3474.69   3481.47
sigma_d       985.14   10.31 455.20    487.32    689.04    869.30   1150.68
mu_ehat       -19.70    0.02   0.70    -21.05    -19.94    -19.69    -19.44
sigma_ehat      0.90    0.02   0.87      0.23      0.41      0.63      1.07
mu_b            1.96    0.02   0.70      0.52      1.55      1.97      2.35
sigma_b         1.43    0.02   0.93      0.51      0.86      1.19      1.69
e[1]            0.00    0.00   0.00      0.00      0.00      0.00      0.00
e[2]            0.00    0.00   0.00      0.00      0.00      0.00      0.00
e[3]            0.00    0.00   0.00      0.00      0.00      0.00      0.00
lp__       245804.20    0.10   3.47 245796.42 245802.16 245804.59 245806.67
               97.5% n_eff Rhat
b[1]            3.36  1986 1.00
b[2]            1.62  1979 1.00
b[3]            1.67  1620 1.00
c[1]           65.15  2807 1.00
c[2]           26.67  2791 1.00
c[3]           22.47  2823 1.00
d[1]         3552.15  2460 1.00
d[2]         2224.56  2155 1.00
d[3]         2657.44  1613 1.00
ehat[1]       -19.27  2174 1.00
ehat[2]       -19.95  2115 1.00
ehat[3]       -19.67  1592 1.00
mu_c           27.40  2501 1.00
sigma_c        72.80  2385 1.00
mu_d         3493.68  6787 1.00
sigma_d      2153.57  1950 1.00
mu_ehat       -18.33  1684 1.00
sigma_ehat      3.26  1328 1.00
mu_b            3.47  2065 1.00
sigma_b         4.01  1823 1.00
e[1]            0.00  2176 1.00
e[2]            0.00  2111 1.00
e[3]            0.00  1585 1.00
lp__       245809.96  1189 1.01

Samples were drawn using NUTS(diag_e) at Tue Mar 24 14:43:38 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).

Chains don’t seem to mix well, and I get warnings about divergent transitions and/or maximum treedepth exceeded.

I know that parameter estimates are correlated, and pairs confirms this. This will always be the case as a change in the slope can be “compensated” (to some degree) by changes in the other parameters.

My question is: What can I do to improve the quality and speed of the model fit? The example here has three groups, my actual data has measurements for close to 1000 groups. Fitting the model with partial pooling becomes very slow in that case. I could fit individual models per group, but that would mean forfeiting the advantages that come with partially pooling parameter estimates across all groups.

Based on other posts and general Stan fitting advice I’ve read, I’ve played around with

  1. increasing the number of iterations,
  2. increasing the max_treedepth, and
  3. increasing the adapt_delta

with moderate success.

My first guess is that you’re running into a funnel problem. If that’s the case you might be able to solve it with a non-centered reparameterization.

What do the pairs plots look like if you do one variable at a time e.g.

pairs(fit, pars = c("c", "mu_c", "sigma_c"))?

Does it look funnel like?

I know that parameter estimates are correlated, and pairs confirms this. This will always be the case as a change in the slope can be “compensated” (to some degree) by changes in the other parameters.

I’d be curious to see which variables are correlated in the pairs plot you already looked at.

Hi @arya. Thanks for the quick response. While I read up more on the funnel problem you mention, here are the pairs plots:

pairs(fit, pars = c("b", "c", "d", "ehat"))

You can see how e.g. b[2] is correlated with c[2], d[2] and ehat[2]. There is some considerable correlation between some variables, see e.g. d[2] and ehat[2]. I think this will always be the case. From what I read in related posts here on the Stan Forum, the suggestion was to increase the number of iterations.


As for

pairs(fit, pars = c("c", "mu_c", "sigma_c"))

Not really sure if this is funnel shaped. What would you say?

It’s hard to tell without looking at sigma_c on the log-scale, but possibly. It looks like there’s an interesting correlation between mu_c and sigma_c. When sigma_c is big, mu_c seems to get smaller.

I wonder if there’s an issue going on with using normal priors on positive parameters without truncating the distribution. It’s pretty easy to code up truncated distributions in Stan, but I’m not sure if that’s the issue.

I was under the impression that I had already used truncated distributions; e.g. for c I have

vector<lower=0>[J] c;
real<lower=0> mu_c;
real<lower=0> sigma_c;
c ~ normal(mu_c, sigma_c);
mu_c ~ normal(0, 10);
sigma_c ~ cauchy(0, 2.5);

Since c has a lower bound of 0, normal(mu_c, sigma_c) should correspond to a truncated normal, no?

Do you think it might make more sense to have a Poisson prior on c and d, since they correspond to asymptotic counts and y ~ poisson(mu_y)? That would reduce the number of hyperparameters. Perhaps this is the issue here?

Following @arya’s suggestion, I changed model_code to implement non-centred parametrisation (as discussed in the Stan user guide).

model_code <- "
functions {
  real LL4(real x, real b, real c, real d, real ehat) {
    return c + (d - c) / (1 + exp(b * (log(x) - ehat)));
  }
}

data {
  int<lower=1> N;                           // Measurements
  int<lower=1> J;                           // Number of drugs
  int<lower=1,upper=J> drug[N];             // Drug of measurement
  vector[N] x;                              // Dose values
  int y[N];                                 // Response values for drug
}

// Transformed data
transformed data {
}

// Sampling space
parameters {
  vector<lower=0>[J] b_raw;            // slope
  vector<lower=0>[J] c_raw;            // lower asymptote
  vector<lower=0>[J] d_raw;            // upper asymptote
  vector[J] ehat_raw;                  // loge(IC50)

  // Hyperparameters
  real<lower=0> mu_c;
  real<lower=0> sigma_c;
  real<lower=0> mu_d;
  real<lower=0> sigma_d;
  real mu_ehat;
  real<lower=0> sigma_ehat;
  real<lower=0> mu_b;
  real<lower=0> sigma_b;
}

// Transform parameters before calculating the posterior
// Use non-centred parametrisation of b, c, d, ehat
transformed parameters {
  vector<lower=0>[J] b;
  vector<lower=0>[J] c;
  vector<lower=0>[J] d;
  vector[J] ehat;
  vector<lower=0>[J] e;

  b = mu_b + sigma_b * b_raw;
  c = mu_c + sigma_c * c_raw;
  d = mu_d + sigma_d * d_raw;
  ehat = mu_ehat + sigma_ehat * ehat_raw;

  e = exp(ehat);
}

// Calculate posterior
model {
  vector[N] mu_y;

  // Parameters
  b_raw ~ std_normal();
  c_raw ~ std_normal();
  d_raw ~ std_normal();
  ehat_raw ~ std_normal();

  // Priors for hyperparameters
  mu_c ~ normal(0, 10);
  sigma_c ~ cauchy(0, 2.5);
  mu_d ~ normal(max(y), 10);
  sigma_d ~ cauchy(0, 2.5);
  mu_ehat ~ normal(mean(log(x)), 10);
  sigma_ehat ~ cauchy(0, 2.5);
  mu_b ~ normal(1, 2);
  sigma_b ~ cauchy(0, 2.5);

  for (i in 1:N) {
    mu_y[i] = LL4(x[i], b[drug[i]], c[drug[i]], d[drug[i]], ehat[drug[i]]);
  }
  y ~ poisson(mu_y);

}
"

Run times are similar, as are estimates. I still end up with what I consider a significant number of divergent transitions (somewhere in the hundreds to thousands).

I’m not entirely sure I correctly implemented the non-centred parametrisation, in particular in light of the parameter bounds.
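
One way I could imagine reconciling non-centring with the positivity constraints (an untested sketch, using b as an example) is to non-centre on the log scale, so that b is lognormal across groups and positive by construction:

```stan
parameters {
  vector[J] log_b_raw;              // standard normal, unconstrained
  real mu_log_b;                    // hyperparameters on the log scale
  real<lower=0> sigma_log_b;
}
transformed parameters {
  // exp() guarantees b > 0; this implies b ~ lognormal(mu_log_b, sigma_log_b)
  vector<lower=0>[J] b = exp(mu_log_b + sigma_log_b * log_b_raw);
}
model {
  log_b_raw ~ std_normal();
}
```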


Usually for parameters that are constrained to be positive I would go for a log-normal, a chi-square, a gamma, or an inverse-gamma.

As for the truncated densities, you need to divide by the area under the curve to get the correct density. So if y is constrained to be positive then a truncated N(0,1) density should be the density of a N(0,1) which we’ll call N(y | 0,1), divided by the normalization term

\int_0^\infty N(y | 0,1) dy.

Since the support is smaller for a truncated normal, i.e. [0, \infty) instead of (-\infty, \infty) this normalization term is there to ensure the density integrates to one and is thus a proper density.
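
In Stan code, making the truncation explicit looks like this (a sketch; vectorized truncation isn’t supported, hence the loop):

```stan
model {
  // Each statement subtracts normal_lccdf(0 | mu_c, sigma_c) from the
  // target. Because mu_c and sigma_c are themselves parameters, this
  // normalization term varies during sampling and is not a constant
  // that can be dropped.
  for (j in 1:J)
    c[j] ~ normal(mu_c, sigma_c) T[0, ];
}
```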

Ok, I will look into this. I’m not convinced that this will make a significant difference though. On that note I also realised that I cannot use Poisson priors on c and d as we cannot have integer arrays as parameters.

Yes, I agree with the mathematical details, but I’m not sure this is relevant in Stan. From what I understand, declaring a lower bound on a parameter ensures that sampling is restricted to the constrained space. For example, from Prior Choice Recommendations

[…] half-normal(0,10) (that’s implemented as normal(0,10) with a <lower=0> constraint in the declaration of the parameter) […]

No need to manually normalise densities.

An update

Log-normal densities for 0-bounded (count-like) parameters

Switching to lognormal densities for c and d significantly improved the runtime (thanks @arya for the suggestion).

The model now looks like this

functions {
  real LL4(real x, real b, real c, real d, real ehat) {
    return c + (d - c) / (1 + exp(b * (log(x) - ehat)));
  }
}

data {
  int<lower=1> N;                           // Measurements
  int<lower=1> J;                           // Number of drugs
  int<lower=1,upper=J> drug[N];             // Drug of measurement
  vector[N] x;                              // Dose values
  int y[N];                                 // Response values for drug
}

// Transformed data
transformed data {
}

// Sampling space
// We need to impose parameter constraints on the LL4 parameters:
// Since c,d ~ LogNormal(), we need <lower=0> for c,d
parameters {
  vector<lower=0>[J] b;                     // slope
  vector<lower=0>[J] c;                     // lower asymptote
  vector<lower=0>[J] d;                     // upper asymptote
  vector[J] ehat;                           // loge(IC50)

  // Hyperparameters
  real<lower=0> mu_b;
  real<lower=0> sigma_b;
  real mu_c;
  real<lower=0> sigma_c;
  real mu_d;
  real<lower=0> sigma_d;
  real mu_ehat;
  real<lower=0> sigma_ehat;
}

// Transform parameters before calculating the posterior
transformed parameters {
  vector<lower=0>[J] e;
  real log_sigma_b;
  real log_sigma_c;
  real log_sigma_d;
  real log_sigma_ehat;

  // IC50
  e = exp(ehat);
  
  log_sigma_b = log10(sigma_b);
  log_sigma_c = log10(sigma_c);
  log_sigma_d = log10(sigma_d);
  log_sigma_ehat = log10(sigma_ehat);
}

// Calculate posterior
model {
  // Declare mu_y in model to make it local (i.e. we don't want mu_y to show
  // up in the output)
  vector[N] mu_y;

  // Parameter priors
  c ~ lognormal(mu_c, sigma_c);
  d ~ lognormal(mu_d, sigma_d);
  ehat ~ normal(mu_ehat, sigma_ehat);
  b ~ normal(mu_b, sigma_b);

  // Priors for hyperparameters
  mu_c ~ normal(0, 5);
  sigma_c ~ cauchy(0, 2.5);
  mu_d ~ normal(max(y), 5);
  sigma_d ~ cauchy(0, 2.5);
  mu_ehat ~ normal(mean(log(x)), 5);
  sigma_ehat ~ cauchy(0, 2.5);
  mu_b ~ normal(1, 2);
  sigma_b ~ cauchy(0, 2.5);

  for (i in 1:N) {
    mu_y[i] = LL4(x[i], b[drug[i]], c[drug[i]], d[drug[i]], ehat[drug[i]]);
  }
  y ~ poisson(mu_y);
}

Correlation between parameters and their hyperparameters

There is still a potential issue with the correlation between the LL4 parameters and their hyperparameters. For example, pairs(fit, pars = c("b", "mu_b", "log_sigma_b")) gives

There is a clear correlation structure between mu_b and the log-transformed sigma_b that looks funnel-shaped.

Question: Is this something to worry about? And if so, what steps can I take to prevent it?

Glad to hear the lognormal worked out. I wonder if the earlier problem was the missing normalization of the truncated normal.

Yeah that’s definitely a funnel but if you’re not getting divergences and otherwise good sampler diagnostics, you should be ok.

I notice that in the original pooled model with normal priors, sigma_d is estimated to be 985, but its prior is cauchy(0, 2.5). If you shift that prior to put more probability mass near the estimated value, does the performance improve?
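
For example, something like this (scale chosen only to match the order of magnitude of the estimate, not a recommendation):

```stan
sigma_d ~ cauchy(0, 500);   // instead of cauchy(0, 2.5)
```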

Thanks @FJCC for pointing that out. In fact, I had a closer look at the d estimates (including hyperparameters), and here are the results

summary(fit, pars = c("d", "mu_d", "sigma_d"))$summary
#            mean    se_mean          sd     2.5%      25%      50%      75%
#d[1]    3489.865  0.5093660   31.741119 3427.630 3467.723 3490.117 3511.441
#d[2]    2127.745  1.1523709   46.468022 2038.463 2096.270 2127.392 2158.768
#d[3]    2576.638  0.7579097   38.674776 2501.041 2549.911 2575.692 2602.964
#mu_d    3474.843  0.0646205    4.871406 3465.109 3471.562 3474.840 3478.181
#sigma_d 3779.928 40.9239714 1960.299861 1791.430 2589.539 3245.744 4337.591
#           97.5%    n_eff      Rhat
#d[1]    3550.748 3883.154 0.9997746
#d[2]    2221.764 1626.012 1.0003123
#d[3]    2653.945 2603.878 1.0004396
#mu_d    3484.382 5682.874 0.9996138
#sigma_d 8859.847 2294.507 1.0022283

Note that sigma_d is as large as mu_d itself. I find that odd, considering that the measured maxima of the responses per group (which is roughly what d should correspond to) are

by(data_stan$y, data_stan$drug, FUN = max)
#data_stan$drug: 1
#[1] 3475
#------------------------------------------------------------
#data_stan$drug: 2
#[1] 2092
#------------------------------------------------------------
#data_stan$drug: 3
#[1] 2575

Ok, d[1], d[2] and d[3] agree with these numbers. But the mean and sd of these numbers are 2714 and 701, respectively. I’m very confused as to why the mu_d and sigma_d estimates are so large; that doesn’t seem to make sense to me. The whole point of pooling was to shrink estimates towards the overall mean.

First, a disclaimer - I am just a novice Stan user trying to learn by participating in this forum. Don’t take anything I say as authoritative.

I am unsure which model produced your last results but in any case, I think of mu_d as being estimated from only three values. d is the upper asymptote of a drug and you only have three of those. In any context, given the three numbers 2100, 2600, and 3500, you can’t say much about the distribution of the population that produced the values other than it is typically in the low thousands, kind of, probably. What really strikes me is the very tight estimate on mu_d coupled with the huge standard deviation on sigma_d. Is the prior on mu_d still normal(max(y), 5)? That is really tight and matches the posterior estimate. That might drive sigma_d to larger values so that d[2] down at 2100 is reasonably probable. That isn’t expensive in terms of the total probability because there is only one value up at 3500 that loses probability as the sigma increases. If you do have the tight prior on mu_d, try making it much broader and not pinned to the global maximum and see how the result changes. Maybe normal(mean(y), 750). Though, again, I am unsure which model you are using so that suggestion might not make sense.
Edit: I wrote the wrong thing with normal(mean(y), 750). What I meant by mean(y) there is mean(y_max), the mean of the three maximum y values, one for each drug.
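
In Stan, that suggestion would look something like the line below, where 2714 is the mean of the three per-drug maxima (3475, 2092, 2575) quoted above, and 750 is just a ballpark scale:

```stan
mu_d ~ normal(2714, 750);   // centered on mean(y_max) rather than max(y)
```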