Dose Response Model with partial pooling on maximum value

A minor clarification I forgot to mention - the non-centered parametrization only makes sense when you do partial pooling across multiple curves; it is likely to cause problems when you only have one curve… My intention was that you try the model from the very first post with the non-centered parametrization. Maybe your data actually work even with the default implementation and the main problem was the hierarchical part… I would definitely try that first, before playing with the experiment below.


Anyway, you inspired me to give the sigmoid another deeper look and I managed to push an older idea further; it seems to work better than the previous one. The main idea is simple. Let’s assume the majority of our x values spans roughly between -1 and 1 (but we have a few outside this interval). Having defined the sigmoid function as:

f(x, \beta, \alpha, k) = \frac{k}{1 + e^{-\alpha - \beta x}}

We can introduce three new parameters a,b,c, such that:

a = f(-1,\beta, \alpha, k) \\ b = f(0,\beta, \alpha, k) \\ c = f(1,\beta, \alpha, k) \\

Now a,b,c should be well informed by the data - they are the values of the curve within the range our data cover. They should also be somewhat independent, provided that there are enough data points between each pair of our anchor points (the method fails with too few data).

There is a slight issue - not all combinations of a,b,c values define a sigmoid. Obviously, sigmoids are monotonic, so we either need 0 < a < b < c or 0 < c < b < a, but is that enough? And how do we solve for \alpha, \beta, k from a,b,c? Now I am not good enough at math to work this out easily by myself, but I’ve enlisted the help of the free version of the Wolfram Cloud (which I can heavily recommend for doing symbolic math) - my code is here: Wolfram Cloud Document

EDIT: I copied the constraint with a mistake, fixed.

Our friend Wolfram tells us that we have an additional constraint b^2 > ac and that the solution is (regardless of whether we have a rising or falling sigmoid):

\alpha = \log{\left(\frac{a}{a-b} + \frac{b}{c-b}\right)}\\ \beta = \log\frac{a-b}{b-c} + \log{c} - \log{a}\\ k = \frac{b \left(a (b - 2 c) + b c\right)}{b^2 - a c}
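A quick numerical sanity check of these formulas in R (the particular \alpha, \beta, k values below are arbitrary, just for illustration): evaluate the sigmoid at x = -1, 0, 1 to get a,b,c and then recover the parameters from the closed-form solution.

f <- function(x, beta, alpha, k) k / (1 + exp(-alpha - beta * x))

alpha_true <- -0.5; beta_true <- 2; k_true <- 30
a <- f(-1, beta_true, alpha_true, k_true)
b <- f( 0, beta_true, alpha_true, k_true)
c <- f( 1, beta_true, alpha_true, k_true)

# Recover alpha, beta, k from a, b, c via the solution above
alpha_rec <- log(a / (a - b) + b / (c - b))
beta_rec  <- log((a - b) / (b - c)) + log(c) - log(a)
k_rec     <- b * (a * (b - 2 * c) + b * c) / (b^2 - a * c)

rbind(true = c(alpha_true, beta_true, k_true),
      recovered = c(alpha_rec, beta_rec, k_rec))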

So let us assume a lognormal prior on a,b,c to ensure positivity and generate some fake data as a prior predictive check. Note that we can take arbitrary a,c > 0; then only b is constrained, to match the rising/falling shape of the sigmoid and to satisfy b^2 > ac.

We generate i.i.d. lognormal values, order them, and swap the order with 50% probability. To satisfy the b^2 > ac constraint we do what is fancifully called “rejection sampling”, which basically means we discard all draws that do not satisfy the constraint (those tend to be rare):

library(tidyverse)

set.seed(4949393)
N <- 60
x_obs <- seq(-1.5, 1.5, length.out =  100)
sim_data <- tibble(
       # Generate iid log normal
       y1 = rlnorm(N, log(15), log(5)),
       y2 = rlnorm(N, log(15), log(5)),
       y3 = rlnorm(N, log(15), log(5)),
       sigma_g = abs(rnorm(N, 0, 0.5))) %>%
  mutate(a = pmin(y1,y2,y3), c = pmax(y1,y2,y3),  #Order
         b = if_else(a == y1,
                     if_else(c == y3, y2, y3),
                     if_else(a == y2,
                             if_else(c == y3, y1, y3),
                             if_else(c == y2, y1, y2))),
         swap = rbinom(N, 1, 0.5) == 1, #Potentially swap
         a_ = if_else(swap, c, a),
         c = if_else(swap, a, c),
         a = a_
         ) %>%
  select(-swap, -a_) %>% # Drop the helper columns
  filter(b^2 > a * c) %>% # Rejection sampling
  mutate(
    draw = 1:n(),
    alpha = log(a/(a-b) + b / (c-b)), # Compute target parameters
    beta = log((a-b)/(b-c)) + log(c) - log(a),
    k = b * (a * (b - 2 * c) + b * c) / (b^2 - a * c)
  ) %>%
  crossing(tibble(x = x_obs)) %>%
  mutate(
    mu = k/(1+exp(-beta*x - alpha)),
    y = rlnorm(n(), log(mu), sigma_g)) 


sim_data %>%
  filter(abs(mu) < 200) %>%
  ggplot(aes(x = x, y = mu, group = draw)) + geom_line(alpha = 0.5) 

Looks like a wide range of sigmoids is a priori plausible, so we’re happy for now.
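(If you are curious how many prior draws survived the rejection step, a one-liner with dplyr does it:)

n_distinct(sim_data$draw)  # number of prior draws left after the b^2 > ac filter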

So, now our Stan model. Since we expect positive-only values, I took the liberty of moving from normal noise to lognormal noise, as this seems more natural (but you understand your data better, so if this doesn’t make sense to you, keep normal noise). We directly enforce the constraints on b via lower and upper (note: a,b,c > 0 and b^2 > ac imply b > min\{a,c\}, since b^2 > ac \geq min\{a,c\}^2, so we can leave this one out). We need to do a max call, which is potentially problematic - the posterior is still continuous, but it is not smooth. In this case this seems not to be a big issue, as with enough data b is sampled far from the bounds and the non-smooth region is not visited.

EDIT: The previous version had an extra constraint that was unnecessary and caused divergences.

data {
  int<lower=1> N;                           // Number of observations
  vector[N] x;                      // Dose values (air temperature)
  vector[N] y;                     // winter hardiness, multiplied by -1 to be positive
}

// Sampling space
parameters {
  real<lower=0> a;
  real<lower=0> c;
  // Enforce the constraints (yes, this is completely valid code)
  real<lower=sqrt(a * c), upper=max({a,c})> b;
  real<lower=0> sigma_g;  

}

transformed parameters {
  real alpha = log(a/(a-b) + b / (c-b));
  real beta = log((a-b)/(b-c)) + log(c) - log(a);
  real<lower=0> k = b * (a * (b - 2 *c) + b * c) / (b^2 - a * c);
}

// Calculate posterior
model {
  vector[N] logmu_y;

  // priors
  sigma_g ~ normal(0, 0.5);           
  a ~ lognormal(log(15), log(5));            
  b ~ lognormal(log(15), log(5));           
  c ~ lognormal(log(15), log(5)); 

  //likelihood function 
  for (i in 1:N) {
    //Staying on log scale, the line below is equivalent to
    // mu_y[i] = k/(1+exp(-beta*x[i] - alpha));
    logmu_y[i] = log(k) - log1p_exp(-beta*x[i] - alpha);
  }
  y ~ lognormal(logmu_y, sigma_g); 
}

generated quantities {
  
  // Reversing Stan's constraints for diagnostics
  real b_raw;
  
  // Compute the mean curve and simulate replicated data for posterior predictive checks
  vector[N] logmu_y;                      // Mean on the log scale
  vector[N] y_sim;                        // Simulated data

  //likelihood function 
  for (i in 1:N) {
    logmu_y[i] = log(k) - log1p_exp(-beta*x[i] - alpha);
  }
  // Simulate data from observational model
  for (n in 1:N) y_sim[n] = lognormal_rng(logmu_y[n], sigma_g); 
  
  {
    // Recompute the bounds and reverse Stan's lower/upper constraint transform
    real b_low = max({min({a,c}), sqrt(a * c)});
    real b_up = max({a,c});
    b_raw = logit((b - b_low) / (b_up - b_low));
  }
}
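(For completeness: the fitting code below uses a compiled model object model_mod; with rstan, something like the following produces it - the file name is just a placeholder for wherever you saved the Stan code above.)

library(rstan)
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

# "sigmoid_abc.stan" is a placeholder file name containing the model above
model_mod <- stan_model("sigmoid_abc.stan")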

And let’s fit it to some of the data we simulated:

library(rstan)
library(tidybayes)

set.seed(5949445)
draw_to_use <- 1

N_x <- 15
data_filtered <- sim_data %>% filter(draw == draw_to_use) %>% sample_n(N_x)
data_for_plot_true <- data_filtered %>%
  pivot_longer(c("mu","y"), names_to = "name", values_to = "value")

data_for_stan <- list(N = N_x, x = data_filtered$x, y = data_filtered$y)
fit <- sampling(model_mod, data = data_for_stan)

pars_of_interest <- c("a", "b_raw","c", "sigma_g", "lp__")

summary(fit, pars = pars_of_interest)$summary
summary(fit, pars = c("alpha","beta","k"))$summary

bayesplot::mcmc_pairs(fit, pars = pars_of_interest, transformations = list(a = "log", c = "log"), 
                      np = bayesplot::nuts_params(fit))

gather_draws(fit, logmu_y[i], n = 50) %>%
  inner_join(tibble(i = 1:N_x, x = data_for_stan$x), by = "i") %>%
  ggplot(aes(x = x, y = exp(.value), group = .draw)) +
  geom_line(alpha = 0.2) +
  geom_line(aes(x = x, y = value, color = name), data = data_for_plot_true,
            inherit.aes = FALSE, size = 2)

Recovers neatly:

The pairs plot is not perfect, but the strong correlations are gone:

Also note that while we have a pretty good idea what a,b,c are, we can’t really constrain k:

              mean     se_mean        sd        2.5%         25%         50%         75%       97.5%    n_eff      Rhat
a         4.2811747 0.020926543  1.144839   2.5241839   3.4879422   4.1238347   4.8887739   6.9103971 2992.914 1.0007751
b_raw    -1.3489155 0.021847439  1.000944  -3.5588657  -1.9071577  -1.2759029  -0.6997127   0.3779332 2099.028 0.9997686
c       214.7881633 1.652218995 73.590397 105.9263605 162.3593716 203.4160025 254.7649242 387.6245267 1983.842 0.9992487
sigma_g   0.6633307 0.002493904  0.125718   0.4671061   0.5708108   0.6487263   0.7355805   0.9519432 2541.179 1.0004543
lp__     -6.0126353 0.040435991  1.571293  -9.8941125  -6.7714416  -5.6472307  -4.8603798  -4.0675430 1510.004 1.0011529

             mean      se_mean           sd       2.5%        25%         50%        75%      97.5%    n_eff      Rhat
alpha  -0.6950607   0.02186324 9.810327e-01  -2.708883  -1.257854  -0.6879693  -0.109265   1.175497 2013.437 0.9995636
beta    3.3406841   0.01211739 6.491872e-01   2.190443   2.897072   3.2918770   3.722365   4.768387 2870.263 1.0000370
k     485.9463804 189.54584919 1.116032e+04 108.837253 172.298576 224.9516964 302.439180 829.057592 3466.768 1.0001177

I still get some divergences when the data are consistent with almost no change (I guess this is the max biting me; I will think about whether I can somehow get rid of it), so it is not perfect, but I think it is better than the previous parametrization…
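(In case it is useful when you try this on your own data: rstan has built-in helpers for checking divergences and other HMC diagnostics, which is how I keep an eye on them - this is standard rstan functionality, nothing specific to this model.)

# Count divergent transitions and run the built-in HMC diagnostics
rstan::get_num_divergent(fit)
rstan::check_hmc_diagnostics(fit)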

Well, enough exploration for today… We’ll see if that actually helps you, but I learned a trick or two on the way :-)
