Piecewise Linear Mixed Models With a Random Change Point

Hi @paul.buerkner ,

I’m trying to build on the above thread to implement a “switch model” as described in the supplementary information (page 2) of this manuscript:

It’s very similar to a piecewise linear model, except that both segments are flat (no slopes) and the residual variance is allowed to differ before and after the “switch” (break) point.
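
In my own notation (this is how I read the supplement, not a quote from it), the model is roughly:

response_i = b0 + b1 \cdot 1[predictor_i < \omega] + b2 \cdot 1[predictor_i \ge \omega] + u_{group(i)} + \epsilon_i,
\epsilon_i \sim Normal(0, \sigma_1) if predictor_i < \omega, and Normal(0, \sigma_2) otherwise,

where \omega is the switch point and u_{group(i)} is a group-level intercept.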

I think I got the model structure right (?), but fitting is on the slow side and I get convergence warnings. The effective sample sizes are also quite low, even though the mean effects are recovered correctly. Do you have any ideas on how to speed up this implementation and eliminate the convergence issues? Reprex below.

Thanks!

library(tidyverse)
library(brms)
rstan::rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

set.seed(10)
b0 <- -1
b1 <- 3
b2 <- -4
s1 <- 0.5
s2 <- 1
omega <- 30
group_error <- rnorm(50, 0, 0.01)
names(group_error) <- as.character(1:50)
df <- data.frame(predictor = seq(0.2, 1e2, 0.2)) %>%
  dplyr::mutate(group_id = rep(1:50, each = 10) %>%
                  sample %>%
                  as.factor,
                mean_error = ifelse(predictor < omega,
                                    rnorm(sum(predictor < omega), 0, s1),
                                    rnorm(sum(predictor >= omega), 0, s2)),
                response = b0 + ifelse(predictor - omega < 0, b1, b2) +
                  mean_error + group_error[group_id])

ggplot(data = df, mapping = aes(y = response, x = predictor)) +
  geom_point()

bform_switch <- brms::bf(response ~ b0 + b1 * step(omega - predictor) +
                            b2 * step(predictor - omega),
                         # keep omega within the range of predictor
                         brms::nlf(omega ~ inv_logit(alpha) * 1e2),
                         # allow error to change with switch
                         brms::nlf(sigma ~ s1 * step(omega - predictor) +
                                     s2 * step(predictor - omega)),
                         b0 ~ 1 + (1 | group_id),
                         s1 + s2 + b1 + b2 + alpha ~ 1,
                         nl = TRUE)

bprior <- prior(normal(0, 2), nlpar = "b0") +
          prior(normal(0, 2), nlpar = "b1") +
          prior(normal(0, 2), nlpar = "b2") +
          prior(normal(0, 1), nlpar = "alpha") +
          prior(normal(0, 1), nlpar = "s1") +
          prior(normal(0, 1), nlpar = "s2")

fit_s <- brms::brm(bform_switch, data = df, prior = bprior)

These are the warning messages I got:

1: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded 
2: Examine the pairs() plot to diagnose sampling problems
 
3: The largest R-hat is 1.72, indicating chains have not mixed.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#r-hat 
4: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#bulk-ess 
5: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#tail-ess 
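
Following warning 1, I could simply rerun with a higher max_treedepth (plus a higher adapt_delta and more iterations), e.g.:

fit_s_long <- brms::brm(bform_switch, data = df, prior = bprior,
                        iter = 4000, warmup = 2000,
                        control = list(max_treedepth = 12, adapt_delta = 0.95))

but that only makes each iteration more expensive, and I'm not sure tree depth is the real problem here.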

The fit seems to recover the simulated values well (note that s1 and s2 are estimated on the log scale because sigma uses a log link, so expect s1 \approx log(0.5) \approx -0.69 and s2 \approx log(1) = 0).

> fit_s
 Family: gaussian 
  Links: mu = identity; sigma = log 
Formula: response ~ b0 + b1 * step(omega - predictor) + b2 * step(predictor - omega) 
         omega ~ inv_logit(alpha) * 100
         sigma ~ s1 * step(omega - predictor) + s2 * step(predictor - omega)
         b0 ~ 1 + (1 | group_id)
         s1 ~ 1
         s2 ~ 1
         b1 ~ 1
         b2 ~ 1
         alpha ~ 1
   Data: df (Number of observations: 500) 
Samples: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup samples = 4000

Group-Level Effects: 
~group_id (Number of levels: 50) 
                 Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(b0_Intercept)     0.04      0.04     0.00     0.14 1.61        7       12

Population-Level Effects: 
                Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
b0_Intercept       -1.61      0.52    -2.40    -0.37 1.29       12       25
s1_Intercept       -0.60      0.06    -0.71    -0.48 1.01      463      860
s2_Intercept       -0.02      0.04    -0.09     0.06 1.01      368      905
b1_Intercept        3.62      0.52     2.36     4.41 1.30       11       25
b2_Intercept       -3.43      0.52    -4.70    -2.65 1.30       11       24
alpha_Intercept    -0.85      0.00    -0.86    -0.85 1.00      961     2520

Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
Warning message:
Parts of the model have not converged (some Rhats are > 1.05). Be careful when analysing the results! We recommend running more iterations and/or setting stronger priors. 
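
As a quick sanity check, back-transforming the sigma intercepts from the summary above gets close to the simulated values of 0.5 and 1:

> exp(c(s1 = -0.60, s2 = -0.02))
       s1        s2 
0.5488116 0.9801987 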

Here is a conditional-effects plot based on posterior predictions, which also shows the before/after change in residual error:

plot(brms::conditional_effects(fit_s, method = "posterior_predict"),
     points = TRUE)
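
One idea I've been wondering about but haven't tried properly yet (just a sketch, and the steepness constant 10 is an arbitrary choice of mine): replacing the hard step() with a steep inv_logit() so the likelihood is smooth in omega, which I understand NUTS copes with much better than a hard discontinuity:

bform_smooth <- brms::bf(response ~ b0 +
                            b1 * inv_logit((omega - predictor) * 10) +
                            b2 * inv_logit((predictor - omega) * 10),
                         # keep omega within the range of predictor
                         brms::nlf(omega ~ inv_logit(alpha) * 1e2),
                         # allow error to change (smoothly) at the switch
                         brms::nlf(sigma ~ s1 * inv_logit((omega - predictor) * 10) +
                                     s2 * inv_logit((predictor - omega) * 10)),
                         b0 ~ 1 + (1 | group_id),
                         s1 + s2 + b1 + b2 + alpha ~ 1,
                         nl = TRUE)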

For the sake of transparency, this is the output of my sessionInfo():

R version 4.0.2 (2020-06-22)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS Catalina 10.15.7

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib

locale:
[1] en_AU.UTF-8/en_AU.UTF-8/en_AU.UTF-8/C/en_AU.UTF-8/en_AU.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] brms_2.14.4     Rcpp_1.0.5      forcats_0.5.0   stringr_1.4.0   dplyr_1.0.2     purrr_0.3.4     readr_1.4.0     tidyr_1.1.2     tibble_3.0.4    ggplot2_3.3.2   tidyverse_1.3.0

loaded via a namespace (and not attached):
  [1] minqa_1.2.4          colorspace_2.0-0     ellipsis_0.3.1       ggridges_0.5.2       rsconnect_0.8.16     estimability_1.3     markdown_1.1         base64enc_0.1-3      fs_1.5.0            
 [10] rstudioapi_0.13      farver_2.0.3         rstan_2.21.3         DT_0.16              fansi_0.4.1          mvtnorm_1.1-1        lubridate_1.7.9.2    xml2_1.3.2           codetools_0.2-16    
 [19] bridgesampling_1.0-0 splines_4.0.2        shinythemes_1.1.2    bayesplot_1.7.2      projpred_2.0.2       jsonlite_1.7.1       nloptr_1.2.2.2       broom_0.7.2          dbplyr_2.0.0        
 [28] shiny_1.5.0          compiler_4.0.2       httr_1.4.2           emmeans_1.5.0        backports_1.2.0      assertthat_0.2.1     Matrix_1.2-18        fastmap_1.0.1        cli_2.1.0           
 [37] later_1.1.0.1        prettyunits_1.1.1    htmltools_0.5.0      tools_4.0.2          igraph_1.2.6         coda_0.19-4          gtable_0.3.0         glue_1.4.2           reshape2_1.4.4      
 [46] V8_3.4.0             cellranger_1.1.0     vctrs_0.3.5          nlme_3.1-148         crosstalk_1.1.0.1    ps_1.4.0             lme4_1.1-25          rvest_0.3.6          mime_0.9            
 [55] miniUI_0.1.1.1       lifecycle_0.2.0      gtools_3.8.2         statmod_1.4.35       MASS_7.3-51.6        zoo_1.8-8            scales_1.1.1         colourpicker_1.1.0   hms_0.5.3           
 [64] promises_1.1.1       Brobdingnag_1.2-6    parallel_4.0.2       inline_0.3.16        shinystan_2.5.0      curl_4.3             gamm4_0.2-6          gridExtra_2.3        StanHeaders_2.21.0-6
 [73] loo_2.3.1            stringi_1.5.3        dygraphs_1.1.1.6     boot_1.3-25          pkgbuild_1.1.0       rlang_0.4.8          pkgconfig_2.0.3      matrixStats_0.57.0   lattice_0.20-41     
 [82] labeling_0.4.2       rstantools_2.1.1     htmlwidgets_1.5.2    processx_3.4.4       tidyselect_1.1.0     plyr_1.8.6           magrittr_2.0.1       R6_2.5.0             generics_0.1.0      
 [91] DBI_1.1.0            pillar_1.4.6         haven_2.3.1          withr_2.3.0          mgcv_1.8-31          xts_0.12.1           abind_1.4-5          modelr_0.1.8         crayon_1.3.4        
[100] grid_4.0.2           readxl_1.3.1         callr_3.5.1          threejs_0.3.3        reprex_0.3.0         digest_0.6.27        xtable_1.8-4         httpuv_1.5.4         RcppParallel_5.0.2  
[109] stats4_4.0.2         munsell_0.5.0        shinyjs_2.0.0       