Piecewise Linear Mixed Models With a Random Change Point

Is it possible to implement a piecewise linear mixed model with a random change point in brms? This paper implements the model in Stan:

Bayesian Piecewise Linear Mixed Models With a Random Change Point

  • Operating System: Windows
  • brms Version: 2.4.0

Thanks!


I haven’t read the paper in detail, but you may try the following model. I built it rather ad hoc, so I’m not sure it’s the best solution we have in brms. Please also note that the priors are only there to get the model running with this nonsensical simulated data.

library(brms)

bform <- bf(
  y ~ b0 + b1 * (age - omega) * step(omega - age) + 
    b2 * (age - omega) * step(age - omega),
  b0 + b1 + b2 + alpha ~ 1 + (1|person),
  # to keep omega within the age range of 0 to 10
  nlf(omega ~ inv_logit(alpha) * 10),
  nl = TRUE
)

df <- data.frame(
  y = rnorm(330),
  age = rep(0:10, 30),
  person = rep(1:30, each = 11)
)

bprior <- prior(normal(0, 3), nlpar = "b0") +
  prior(normal(0, 3), nlpar = "b1") +
  prior(normal(0, 3), nlpar = "b2") +
  prior(normal(0, 3), nlpar = "alpha")

make_stancode(bform, data = df, prior = bprior)

fit <- brm(bform, data = df, prior = bprior, chains = 1)
summary(fit)

# at the time of writing (brms 2.4.0) this needed the GitHub version of brms;
# since brms 2.10.0 this function is called conditional_effects()
marginal_effects(fit)

This looks cool. I’ll try it out. Thanks, Paul

Nice piece of code. I’m intrigued by what the step function is doing in this instance; can anyone shed any light on it?

Similarly, how would one extend the code above to allow for a quadratic relationship in either or both slopes?

thanks in advance,
Jack

I stumbled upon this while wondering whether there’s anything brms can’t do.

Actually, I learned quite a bit of syntax that was completely new to me; e.g., writing several non-linear parameters on the left-hand side, as in b0 + b1 + b2 + alpha ~ 1 + (1|person), is just shorthand for giving all of them the same formula.

For the question about the step function, I tried the RStudio help (which points to stats::step(), R’s AIC-based stepwise model selection) and googling “r formula step” and “brms formula step”. No luck.

Then it dawned on me that brms relies on Stan to do the heavy lifting of HMC sampling, so maybe step is a Stan function? So I read the Stan code generated by make_stancode, as suggested by @paul.buerkner in post #2. And there it was!

So I went back and googled “stan function step”, and found the documentation entry for step (you have to scroll down a bit).

And it does a simple thing: if its argument is positive, it returns 1; if the argument is negative, it returns 0.
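To see what step() does inside Paul’s formula, here is a small base-R sketch. stan_step() and piecewise_mu() are hypothetical helpers, not brms or Stan functions, and stan_step() assumes step(0) = 1; the zero case does not matter in this formula because (age - omega) is 0 there anyway. Each step() acts as an on/off switch: the b1 term is active only before omega and the b2 term only after it, so b0 is the mean at the change point and b1, b2 are the slopes on either side.

```r
# Hypothetical stand-in for Stan's step(): 1 for non-negative input, 0 otherwise
stan_step <- function(x) as.numeric(x >= 0)

# Mean function of the brms formula above, written in plain R
piecewise_mu <- function(age, b0, b1, b2, omega) {
  b0 + b1 * (age - omega) * stan_step(omega - age) +
    b2 * (age - omega) * stan_step(age - omega)
}

# Slope 1 up to the change point at age 4, then slope -0.5; value 2 at age 4
mu <- piecewise_mu(age = 0:10, b0 = 2, b1 = 1, b2 = -0.5, omega = 4)
print(mu)
```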


Dear Paul,
Thank you for outlining the model formula above; it is really helpful! I have a rich dataset with, on average, 100 repeated observations per individual, and reason to suspect that there is actually more than one random change point. Is it possible to add a third slope (b3) and a second random change point (between b2 and b3), and how would one go about it? In particular, I am not sure how to add a second alpha term.
Thanks!

I am not sure exactly, but one way to do it would be to introduce a new omega (and a new alpha if needed) and add them to the model. Then you need to make sure somehow that one of the omegas is always larger than the other to identify the model.

In any case, it may make sense to first write down the mathematical model you want to estimate, so that we have a basis for thinking about how to express it in brms or Stan.
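A sketch of the ordering idea, under stated assumptions: one way (not taken from this thread) to guarantee omega1 < omega2 is to build the second change point from the first plus a positive gap, with both pieces squashed by the inverse logit. ordered_changepoints() and three_segment_mu() are hypothetical base-R helpers; they use R comparisons where a brms formula would use step(), and carrying the same expressions over into nlf() calls is untested here.

```r
inv_logit <- function(x) plogis(x)  # base R's logistic CDF is the inverse logit

# Hypothetical ordering transform: omega1 lies in (0, upper) and
# omega2 lies in (omega1, upper), so omega1 < omega2 always holds
ordered_changepoints <- function(alpha1, alpha2, upper = 10) {
  omega1 <- inv_logit(alpha1) * upper
  omega2 <- omega1 + inv_logit(alpha2) * (upper - omega1)
  c(omega1 = omega1, omega2 = omega2)
}

# Hypothetical three-segment mean, continuous at both change points;
# b0 is the mean at omega1, and b1, b2, b3 are the slopes of the three segments
three_segment_mu <- function(age, b0, b1, b2, b3, omega1, omega2) {
  b0 +
    b1 * (age - omega1) * (age < omega1) +
    b2 * (pmin(age, omega2) - omega1) * (age >= omega1) +
    b3 * (age - omega2) * (age >= omega2)
}

om <- ordered_changepoints(alpha1 = 0, alpha2 = 0)   # omega1 = 5, omega2 = 7.5
three_segment_mu(c(0, 5, 7.5, 10), b0 = 1, b1 = 2, b2 = -1, b3 = 0.5,
                 omega1 = om[["omega1"]], omega2 = om[["omega2"]])
```

Parameterizing the gap rather than omega2 directly is what removes the label-switching problem: no value of (alpha1, alpha2) can put the second change point before the first.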

Thanks for the quick answer. I’ll see if I can come up with a mathematical model.

@paul.buerkner, there’s a question for you at the bottom of this post :-)

I tried working more on this and made some progress. I wanted to extend @paul.buerkner’s code by:

  • modeling the intercept and slopes at the population level so that only the change points vary between participants.
  • obtaining more interpretable parameter estimates.

Simulated data

Here is some data with common slopes but varying-by-person change points:

library(ggplot2)

# Data parameters
set.seed(1)  # for reproducibility
intercept = -2
slope1 = 1.0
slope2 = -0.5
breakpoints = runif(30, 2, 7)  # one true change point per person

# Predictors
df <- data.frame(
  age = rep(seq(0, 10, by=0.5), 30),
  person = factor(rep(1:30, each = 21))
) 

# Response with per-person break point
# (indexing by the factor uses its integer codes, so person k gets breakpoints[k])
df$y = intercept + ifelse(
    df$age <= breakpoints[df$person],
    yes = slope1 * df$age,  # before change
    no = slope1 * breakpoints[df$person] +
      slope2 * (df$age - breakpoints[df$person])  # after change
  )
df$y = rnorm(nrow(df), mean = df$y, sd = 0.5)

# Plot it
ggplot(df, aes(y=y, x=age, color=person)) + 
  geom_point() + 
  geom_line() + 
  theme(legend.position = 'none')

Specify model

We specify a slope-change-only model: the second segment has no intercept of its own but starts where the first segment ends at the change point, i.e. at Intercept + slope1 * change. The prior on change is important.

# The model
bform3 <- bf(
  y ~ Intercept + slope1 * age * step(change - age) +  # Section 1
    (slope1 * change + slope2 * (age - change)) * step(age - change),  # Section 2
  Intercept + slope1 + slope2 ~ 1,  # Fixed intercept and slopes
  change ~ 1 + (1|person),  # Per-person changepoints around pop mean
  nl = TRUE
)

# Priors
bprior3 <- prior(normal(0, 5), nlpar = "Intercept") +
  prior(normal(0, 2), nlpar = "slope1") +
  prior(normal(0, 2), nlpar = "slope2") +
  prior(uniform(0, 10), nlpar = "change")  # Within observed range

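# Initial values
# brms names the population-level effect of each non-linear parameter
# b_<nlpar> (see make_stancode(bform3, data = df, prior = bprior3)) and
# stores it as a length-1 vector, hence these names and array()
inits3 = list(list(
  b_slope1 = array(slope1),
  b_slope2 = array(slope2),
  b_Intercept = array(intercept),
  b_change = array(5)  # start inside the uniform(0, 10) prior
))

# Fit it!
fit3 <- brm(bform3, data = df, prior = bprior3, chains = 1, inits = inits3)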
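As a sanity check on the formula (a sketch, not part of the original post), the bform3 mean can be written as an R function and compared with the function used to simulate the data. stan_step() is a hypothetical stand-in that assumes step(0) = 1. The grid deliberately avoids age == change: under that assumption both step() terms are switched on at exactly that point, and the two functions disagree there.

```r
stan_step <- function(x) as.numeric(x >= 0)  # hypothetical stand-in; assumes step(0) = 1

# Mean function as written in bform3
mu_bform3 <- function(age, Intercept, slope1, slope2, change) {
  Intercept + slope1 * age * stan_step(change - age) +
    (slope1 * change + slope2 * (age - change)) * stan_step(age - change)
}

# Mean function used to simulate the data
mu_sim <- function(age, intercept, slope1, slope2, change) {
  intercept + ifelse(age <= change,
                     slope1 * age,
                     slope1 * change + slope2 * (age - change))
}

age <- seq(0, 10, by = 0.5)   # same grid as the simulated data
change <- 4.3                 # illustrative change point that is not on the grid
all.equal(mu_bform3(age, -2, 1, -0.5, change),
          mu_sim(age, -2, 1, -0.5, change))   # TRUE
```

Because change is a continuous parameter, hitting age == change exactly has probability zero during sampling, so the mismatch at that single point should not affect the fit.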

Inspect fit

The parameter recovery looks (very!) promising:

Group-Level Effects: 
~person (Number of levels: 30) 
                     Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(change_Intercept)     1.46      0.19     1.16     1.90 1.00      111      236

Population-Level Effects: 
                    Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept_Intercept    -2.02      0.05    -2.13    -1.92 1.01      778     1006
slope1_Intercept        1.02      0.02     0.98     1.06 1.01      828     1061
slope2_Intercept       -0.47      0.01    -0.50    -0.44 1.00     1079     1434
change_Intercept        4.31      0.27     3.74     4.85 1.01      160      259
marginal_effects(fit3)


The individual change points have the right magnitude (all within (7 - 2)/2 = 2.5 of the mean change point), but they do not seem to match the values used in the simulation:

plot(breakpoints, ranef(fit3)$person[,1,'change_Intercept'])


Has the order of person been changed in ranef(fit3) or is there another simple explanation of this, @paul.buerkner?


I hope it is just the order. You can see the order ranef uses by looking at the row names of ranef(fit3)$person. Since you defined person as a factor, the order might be 1, 10, 11, …, 19, 2, 20, etc., which might explain the problems you see.

This is not a bug but a consequence of how R handles factors, though I’m not sure it is desirable behavior either.
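A quick base-R way to see the ordering Paul describes, and to line estimates up by name rather than by position. est is a made-up stand-in for a vector of person-level estimates named by factor level (for example, one column of ranef(fit3)$person); it is not real model output.

```r
ids <- rep(1:30, each = 21)

# A factor built from integers keeps numeric level order ...
levels(factor(ids))[1:4]                 # "1" "2" "3" "4"
# ... while one built from character strings sorts its levels as text
levels(factor(as.character(ids)))[1:4]   # "1" "10" "11" "12"

# Made-up estimates named in text order, standing in for per-person output
est <- setNames(seq_len(30) / 10, levels(factor(as.character(ids))))

# Index by name, not position, so element k belongs to person k
aligned <- est[as.character(1:30)]
head(aligned, 3)
```

With person = factor(rep(1:30, each = 21)) as in the simulation code, the levels are already in numeric order, so if the plot still looks off it may be worth checking that the simulation and the plot use the same change-point vector.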

Hi Paul,

Is using inv_logit here an identification requirement? I ask because I’d like to use a similar model, though my running variable runs from -100 to +100 and, thus, it doesn’t make much sense to constrain omega to be positive.

If it’s a constraint, I can rescale the data. I was just wondering why you’d used it here.

Jack

You could constrain omega to any range of values by scaling the inv_logit appropriately. In the example above I scaled it from 0 to 10 but you can also scale it from -100 to 100.
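The general form of that scaling is omega = lower + (upper - lower) * inv_logit(alpha), so for a running variable on -100 to 100 the brms line would presumably become nlf(omega ~ inv_logit(alpha) * 200 - 100); that line is a sketch, not a tested model. The arithmetic can be checked in base R, where plogis() plays the role of Stan’s inv_logit() and scaled_inv_logit() is a hypothetical helper:

```r
# Map an unconstrained alpha onto the open interval (lower, upper)
scaled_inv_logit <- function(alpha, lower, upper) {
  lower + (upper - lower) * plogis(alpha)
}

scaled_inv_logit(0, lower = 0, upper = 10)             # 5, midpoint of the 0 to 10 example
scaled_inv_logit(0, lower = -100, upper = 100)         # 0
scaled_inv_logit(c(-6, 6), lower = -100, upper = 100)  # close to the bounds, never past them
```

The prior on alpha then implies a prior on omega, so it is worth checking what, say, normal(0, 1) on alpha means on the -100 to 100 scale.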

I’m also interested in fitting a linear mixed model with a random change point. However, my dataset has an additional categorical predictor, “group”, with 3 levels. I’m trying to incorporate a fixed effect of group into @paul.buerkner’s example code above, but am currently at a bit of a loss.

For reference, if we ignore the change point for a moment, my linear mixed model formula would look as follows:
y ~ age * group + (1 + age | person)

However, I have reason to believe that my data would be better fit by introducing a change point. Ideally, I’d like to estimate the effect of my categorical predictor “group” on the change point as well. As a start, I first tried adapting the code above by just including a main effect of “group” as follows:

bform <- bf(
    # intercept and main effect of group
    y ~ b0 + b1 * group +
        # pre-change slope
        b2 * (age - omega) * step(omega - age) +
        # post-change slope
        b3 * (age - omega) * step(age - omega),
    # intercept, slopes and change point varying by person
    b0 + b2 + b3 + omega ~ 1 + (1 | person),
    # fixed effect of group
    b1 ~ 1,
    nl = TRUE
)

However, this gave me the following error message:

Error: Factors with more than two levels are not allowed as covariates.

I saw that this issue came up previously, but I couldn’t really follow the solution in that thread. It seems I should incorporate the non-linear parameters in a linear formula, but I’m not sure how to go about this.

Thanks in advance for any advice!

I think you can also treat the non-linear parameters for the intercept, slopes, and change point as linear models themselves and add group as a predictor to each of them.

Then I think you’d specify the model as follows:

bform <- 
    bf(
    y ~ b0 + b1 * (age - omega) * step(omega - age) +
    b2 * (age - omega) * step(age - omega),
    b0 + b1 + b2 + omega ~ 1 + group + (1 | person),
    nl = TRUE
)

Keep in mind that this is just off the top of my head and I haven’t fit a similar model myself. Others might know whether or not it makes sense.

P.S. The error you got just means that you need to recode your factor variable into a set of dummy variables.
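A minimal base-R sketch of that recoding, assuming treatment (reference) coding and a hypothetical three-level group column. Each resulting 0/1 column could then enter the non-linear formula with its own coefficient, e.g. y ~ b0 + bB * groupb + bC * groupc + ..., where bB and bC are illustrative parameter names.

```r
# Hypothetical data with a three-level factor
d <- data.frame(group = factor(c("a", "b", "c", "a", "c")))

# Treatment coding: drop the intercept column, keep one 0/1 column
# for each non-reference level ("a" is the reference here)
dummies <- model.matrix(~ group, data = d)[, -1]
d <- cbind(d, dummies)

print(colnames(dummies))   # "groupb" "groupc"
print(d)
```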


Hi @paul.buerkner ,

I’m trying to build on the above thread to implement a “switch model” as described in the supplementary information (page 2) of this manuscript:

It’s very similar to a piecewise linear model, only it has no slopes, and it allows for the variance to change before and after the “switch” (break) point.

I think I got the model structure right (?) but the fitting is on the slow-ish end and I get some convergence issues. Moreover, the effective sample size is quite low although the mean effects are correctly estimated. I was wondering if you have any ideas on how to improve this implementation to speed up the code and eliminate the convergence issues? Reprex below.

Thanks!

library(tidyverse)
library(brms)
rstan::rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

set.seed(10)
b0 <- -1
b1 <- 3
b2 <- -4
s1 <- 0.5
s2 <- 1
omega <- 30
group_error <- rnorm(50, 0, 0.01)
names(group_error) <- as.character(1:50)
df <- data.frame(predictor = seq(0.2, 1e2, 0.2)) %>%
  dplyr::mutate(group_id = rep(1:50, each = 10) %>%
                  sample %>%
                  as.factor,
                # one draw per row, with the SD switching from s1 to s2 at omega
                mean_error = rnorm(dplyr::n(), 0, ifelse(predictor < omega, s1, s2)),
                response = b0 + ifelse(predictor - omega < 0, b1, b2) +
                  mean_error + group_error[group_id])

ggplot(data = df, mapping = aes(y = response, x = predictor)) +
  geom_point()

bform_switch <- brms::bf(response ~ b0 + b1 * step(omega - predictor) +
                            b2 * step(predictor - omega),
                         # keep omega within the range of predictor
                         brms::nlf(omega ~ inv_logit(alpha) * 1e2),
                         # allow error to change with switch
                         brms::nlf(sigma ~ s1 * step(omega - predictor) +
                                     s2 * step(predictor - omega)),
                         b0 ~ 1 + (1 | group_id),
                         s1 + s2 + b1 + b2 + alpha ~ 1,
                         nl = TRUE)

bprior <- prior(normal(0, 2), nlpar = "b0") +
          prior(normal(0, 2), nlpar = "b1") +
          prior(normal(0, 2), nlpar = "b2") +
          prior(normal(0, 1), nlpar = "alpha") +
          prior(normal(0, 1), nlpar = "s1") +
          prior(normal(0, 1), nlpar = "s2")

fit_s <- brms::brm(bform_switch, data = df, prior = bprior)

These are the returned warning messages I got:

1: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded 
2: Examine the pairs() plot to diagnose sampling problems
 
3: The largest R-hat is 1.72, indicating chains have not mixed.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#r-hat 
4: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#bulk-ess 
5: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#tail-ess 

The estimates seem to recover the simulated values well (note that s1 and s2 are estimated on the log scale, so expect s1 ≈ log(0.5) ≈ -0.69 and s2 ≈ log(1) = 0).

> fit_s
 Family: gaussian 
  Links: mu = identity; sigma = log 
Formula: response ~ b0 + b1 * step(omega - predictor) + b2 * step(predictor - omega) 
         omega ~ inv_logit(alpha) * 100
         sigma ~ s1 * step(omega - predictor) + s2 * step(predictor - omega)
         b0 ~ 1 + (1 | group_id)
         s1 ~ 1
         s2 ~ 1
         b1 ~ 1
         b2 ~ 1
         alpha ~ 1
   Data: df (Number of observations: 500) 
Samples: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup samples = 4000

Group-Level Effects: 
~group_id (Number of levels: 50) 
                 Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(b0_Intercept)     0.04      0.04     0.00     0.14 1.61        7       12

Population-Level Effects: 
                Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
b0_Intercept       -1.61      0.52    -2.40    -0.37 1.29       12       25
s1_Intercept       -0.60      0.06    -0.71    -0.48 1.01      463      860
s2_Intercept       -0.02      0.04    -0.09     0.06 1.01      368      905
b1_Intercept        3.62      0.52     2.36     4.41 1.30       11       25
b2_Intercept       -3.43      0.52    -4.70    -2.65 1.30       11       24
alpha_Intercept    -0.85      0.00    -0.86    -0.85 1.00      961     2520

Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
Warning message:
Parts of the model have not converged (some Rhats are > 1.05). Be careful when analysing the results! We recommend running more iterations and/or setting stronger priors. 
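One possible contributor to the poor mixing of b0, b1 and b2 (an observation on the output above, not something raised in the thread): away from the switch point, step(omega - predictor) + step(predictor - omega) = 1, so the mean is b0 + b1 before the switch and b0 + b2 after it, and only those two sums are pinned down by the data; b0 can drift as long as b1 and b2 drift with it. Adding up the posterior means from the summary shows that the sums match the simulated levels even though the individual terms do not:

```r
# Posterior means of the population-level effects, copied from the summary above
b0 <- -1.61
b1 <- 3.62
b2 <- -3.43

# Identified quantities: the mean level before and after the switch
levels_hat <- c(before = b0 + b1, after = b0 + b2)
print(levels_hat)

# Values used in the simulation (b0 = -1, b1 = 3, b2 = -4)
levels_true <- c(before = -1 + 3, after = -1 - 4)
print(levels_true)
```

Reparameterizing so that each level is a single parameter would be one way to remove this ridge; that is a suggestion, not something tested here.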

Plotting the posterior predictions, which include the before/after change in residual error:

plot(brms::conditional_effects(fit_s, method = "posterior_predict"),
     points = TRUE)

For the sake of transparency, this is the output of my sessionInfo():

R version 4.0.2 (2020-06-22)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS Catalina 10.15.7

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib

locale:
[1] en_AU.UTF-8/en_AU.UTF-8/en_AU.UTF-8/C/en_AU.UTF-8/en_AU.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] brms_2.14.4     Rcpp_1.0.5      forcats_0.5.0   stringr_1.4.0   dplyr_1.0.2     purrr_0.3.4     readr_1.4.0     tidyr_1.1.2     tibble_3.0.4    ggplot2_3.3.2   tidyverse_1.3.0

loaded via a namespace (and not attached):
  [1] minqa_1.2.4          colorspace_2.0-0     ellipsis_0.3.1       ggridges_0.5.2       rsconnect_0.8.16     estimability_1.3     markdown_1.1         base64enc_0.1-3      fs_1.5.0            
 [10] rstudioapi_0.13      farver_2.0.3         rstan_2.21.3         DT_0.16              fansi_0.4.1          mvtnorm_1.1-1        lubridate_1.7.9.2    xml2_1.3.2           codetools_0.2-16    
 [19] bridgesampling_1.0-0 splines_4.0.2        shinythemes_1.1.2    bayesplot_1.7.2      projpred_2.0.2       jsonlite_1.7.1       nloptr_1.2.2.2       broom_0.7.2          dbplyr_2.0.0        
 [28] shiny_1.5.0          compiler_4.0.2       httr_1.4.2           emmeans_1.5.0        backports_1.2.0      assertthat_0.2.1     Matrix_1.2-18        fastmap_1.0.1        cli_2.1.0           
 [37] later_1.1.0.1        prettyunits_1.1.1    htmltools_0.5.0      tools_4.0.2          igraph_1.2.6         coda_0.19-4          gtable_0.3.0         glue_1.4.2           reshape2_1.4.4      
 [46] V8_3.4.0             cellranger_1.1.0     vctrs_0.3.5          nlme_3.1-148         crosstalk_1.1.0.1    ps_1.4.0             lme4_1.1-25          rvest_0.3.6          mime_0.9            
 [55] miniUI_0.1.1.1       lifecycle_0.2.0      gtools_3.8.2         statmod_1.4.35       MASS_7.3-51.6        zoo_1.8-8            scales_1.1.1         colourpicker_1.1.0   hms_0.5.3           
 [64] promises_1.1.1       Brobdingnag_1.2-6    parallel_4.0.2       inline_0.3.16        shinystan_2.5.0      curl_4.3             gamm4_0.2-6          gridExtra_2.3        StanHeaders_2.21.0-6
 [73] loo_2.3.1            stringi_1.5.3        dygraphs_1.1.1.6     boot_1.3-25          pkgbuild_1.1.0       rlang_0.4.8          pkgconfig_2.0.3      matrixStats_0.57.0   lattice_0.20-41     
 [82] labeling_0.4.2       rstantools_2.1.1     htmlwidgets_1.5.2    processx_3.4.4       tidyselect_1.1.0     plyr_1.8.6           magrittr_2.0.1       R6_2.5.0             generics_0.1.0      
 [91] DBI_1.1.0            pillar_1.4.6         haven_2.3.1          withr_2.3.0          mgcv_1.8-31          xts_0.12.1           abind_1.4-5          modelr_0.1.8         crayon_1.3.4        
[100] grid_4.0.2           readxl_1.3.1         callr_3.5.1          threejs_0.3.3        reprex_0.3.0         digest_0.6.27        xtable_1.8-4         httpuv_1.5.4         RcppParallel_5.0.2  
[109] stats4_4.0.2         munsell_0.5.0        shinyjs_2.0.0       

I find it most likely that the step calls are the issue, as they introduce discontinuities into the density and its gradient. You can usually achieve a similar effect while staying within the realm of continuous functions by replacing step(x) with inv_logit(x * n), where n is a suitable constant: the larger n, the closer you get to the step function, but the more likely you are to run into sampling issues.

Does that make sense?
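A base-R sketch of that replacement: plogis() is R’s inverse logit and stands in for Stan’s inv_logit(), while stan_step() is a hypothetical stand-in for step() that assumes step(0) = 1. The sketch measures how far inv_logit(x * n) is from the hard step at a few distances x from the switch point, for several values of n.

```r
stan_step <- function(x) as.numeric(x >= 0)  # hypothetical stand-in for Stan's step()
smooth_step <- function(x, n) plogis(x * n)   # inv_logit(x * n)

x <- c(-2, -0.5, -0.1, 0.1, 0.5, 2)   # distances from the switch point
ns <- c(1, 5, 20)

# Largest absolute difference from the hard step over these x values
gaps <- sapply(ns, function(n) max(abs(smooth_step(x, n) - stan_step(x))))
names(gaps) <- paste0("n=", ns)
print(signif(gaps, 3))
```

In the switch model the mean depends on omega only through step(), whose derivative is zero almost everywhere, so the smooth version is also what gives the sampler gradient information about omega. Because x * n is on the scale of the predictor, a suitable n depends on the units of the predictor.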


That did the trick, thanks @martinmodrak!
I just replaced step(predictor - omega) with inv_logit((predictor - omega) * 5). No convergence issues, all ESS looking good.
