trunc() and conditional_effects()

  • Operating System: MacOS
  • brms Version: 2.14.4

When I fit a model to truncated data using trunc(), the parameter summary does a good job capturing the data-generating values. However, the fitted lines from conditional_effects() and fitted() do not seem to be based on these values. What am I missing?

Example:

library(tidyverse)
library(brms)

# population n
n <- 1e6

# set the population parameters for y based on a simple linear model
b0 <- 0
b1 <- 0.5
sigma <- 1

# simulate
set.seed(1)
d <-
  # the predictor x is a standard normal
  tibble(x = rnorm(n, mean = 0, sd = 1)) %>% 
  # compute y
  mutate(y = rnorm(n, mean = b0 + b1 * x, sd = sigma))

Now truncate the population by discarding any y values less than or equal to -1.

d_t <-
  d %>% 
  filter(y > -1)

About 81% of the original cases remain. Now take a random sample of n = 1,000 from the truncated population.

set.seed(1)

d_t_1e3 <- 
  d_t %>% 
  sample_n(size = 1e3)

Fit a model on the truncated data using the y | trunc(lb = -1) syntax in the formula argument.

m1 <-
  brm(data = d_t_1e3,
      y | trunc(lb = -1) ~ 1 + x,
      cores = 4, chains = 4)

The model summary suggests the model did a pretty okay job capturing the true values from the simple linear model that generated the original data,

\begin{align*}
y_i & \sim \operatorname{N}(\mu_i, \sigma = 1) \\
\mu_i & = 0 + 0.5\ x_i.
\end{align*}

print(m1)
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: y | trunc(lb = -1) ~ 1 + x 
   Data: d_t_1e3 (Number of observations: 1000) 
Samples: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup samples = 4000

Population-Level Effects: 
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept    -0.11      0.08    -0.27     0.02 1.00     1429     1620
x             0.52      0.05     0.43     0.63 1.00     1600     2122

Family Specific Parameters: 
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma     1.09      0.05     1.00     1.18 1.00     1464     1759

Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

What I don’t understand is why the fitted line in conditional_effects() suggests a non-linear model.

conditional_effects(m1) %>% 
  plot(points = TRUE, point_args = list(size = 1/4))

I get that this non-linear line makes sense in the context of the observed truncated data. I don’t understand what function is being used by conditional_effects() to achieve this non-linear line. Though it’s not shown here, the same behavior occurs when I use fitted().

fitted(), and thus conditional_effects(), which uses fitted() behind the scenes, always returns the mean of the response distribution. This includes correcting the mean for the truncation of the response distribution (for families where we know the mean of the truncated distribution; for other families, fitted() fails). This, I think, explains the non-linearity: the correction is large where the linear predictor is near or below the lower bound of -1, and it shrinks toward zero as x increases and the predictions move further away from that bound.

Regarding the way fitted() corrects for the mean of the truncated distribution, is there a way a user might use the model parameters, such as \beta_0 and \beta_1 above, to recreate that behavior by hand?

Check out the functions towards the end of posterior_epred.R in the brms GitHub repository (paul-buerkner/brms, master branch).
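
For anyone who wants the math behind those functions: for the Gaussian case, the correction is, as far as I can tell, the textbook mean of a truncated normal. This is a sketch rather than a quote from the brms source. The symbols a and b (lower and upper bounds), \alpha_i and \beta_i (standardized bounds), and \phi and \Phi (standard normal density and CDF) are my own notation, not names from the thread.

\begin{align*}
\alpha_i & = \frac{a - \mu_i}{\sigma}, \qquad \beta_i = \frac{b - \mu_i}{\sigma} \\
\operatorname{E}(y_i \mid a < y_i < b) & = \mu_i + \sigma \frac{\phi(\alpha_i) - \phi(\beta_i)}{\Phi(\beta_i) - \Phi(\alpha_i)}
\end{align*}

With a = -1 and b = \infty, as in the model above, \phi(\beta_i) = 0 and \Phi(\beta_i) = 1, so the expression reduces to \mu_i + \sigma\, \phi(\alpha_i) / \{1 - \Phi(\alpha_i)\}.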


Thanks @paul.buerkner.

For those interested, the relevant function in the link appears to be posterior_epred_trunc_gaussian(). Here’s how I used elements from that function to replicate the behavior of fitted() by working directly with the posterior samples.

# extract the posterior samples
posterior_samples(m1) %>% 
  # wrangle
  expand(nesting(b_Intercept, b_x, sigma),
         x = seq(from = min(d_t_1e3$x), to = max(d_t_1e3$x), length.out = 30)) %>% 
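  # linear predictor for each posterior draw at each value of x
  mutate(mu = b_Intercept + b_x * x) %>% 
  # truncation bounds on the scale of y
  mutate(lb = -1,
         ub = Inf) %>% 
  # standardize the bounds
  mutate(zlb = (lb - mu) / sigma,
         zub = (ub - mu) / sigma) %>% 
  # mean of a standard normal truncated to [zlb, zub]
  mutate(trunc_zmean = (dnorm(zlb) - dnorm(zub)) / (pnorm(zub) - pnorm(zlb))) %>% 
  # back-transform to the scale of y
  mutate(trunc_mu = mu + trunc_zmean * sigma) %>% 
  # summarize across posterior draws at each value of x
  group_by(x) %>% 
  tidybayes::mean_qi(trunc_mu) %>% 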
  
  # plot!
  ggplot(aes(x = x)) +
  geom_point(data = d_t_1e3,
             aes(y = y),
             size = 1/4) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper),
              alpha = 1/4) +
  geom_line(aes(y = trunc_mu),
            color = "blue")
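
If you want to convince yourself that the truncation correction in the pipeline above does what it should, here is a minimal base-R sketch that compares it against brute-force simulation. The helper name trunc_normal_mean() is hypothetical (it is not a brms function), and the values of mu, sigma, and lb are arbitrary illustration values rather than posterior draws from m1.

```r
# Hypothetical helper (not part of brms): mean of a normal distribution
# with mean mu and sd sigma, truncated to the interval [lb, ub].
trunc_normal_mean <- function(mu, sigma, lb = -Inf, ub = Inf) {
  # standardize the bounds
  zlb <- (lb - mu) / sigma
  zub <- (ub - mu) / sigma
  # mean of the truncated standard normal, shifted and rescaled
  mu + sigma * (dnorm(zlb) - dnorm(zub)) / (pnorm(zub) - pnorm(zlb))
}

# arbitrary illustration values (assumptions, not estimates from m1)
mu    <- 0.5
sigma <- 1
lb    <- -1

# brute-force check: simulate, discard values at or below lb, take the mean
set.seed(1)
y <- rnorm(1e6, mean = mu, sd = sigma)
sim_mean <- mean(y[y > lb])

# closed-form value from the helper
analytic_mean <- trunc_normal_mean(mu, sigma, lb = lb)

# the two should agree up to simulation error
c(simulated = sim_mean, analytic = analytic_mean)
```

Because the helper is vectorized over mu and sigma, it could stand in for the zlb/zub/trunc_zmean steps inside a mutate() call, though I have only checked it against simulation, not against fitted() directly.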