Hi @paul.buerkner,
I’m trying to build on the above thread to implement a “switch model” as described in the supplementary information (page 2) of this manuscript:
It’s very similar to a piecewise linear model, only it has no slopes, and it allows for the variance to change before and after the “switch” (break) point.
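To spell that out, this is my reading of the model as I coded it in the reprex (the notation is mine, not the manuscript's: $y_i$ is the response, $x_i$ the predictor, $u_{g[i]}$ the group-level intercept, $\tau$ its standard deviation, and the second argument of each normal is a standard deviation):

```latex
\begin{aligned}
y_i &\sim \mathcal{N}(\mu_i, \sigma_i) \\
\mu_i &= b_0 + u_{g[i]} + b_1 \,\mathbb{1}[x_i < \omega] + b_2 \,\mathbb{1}[x_i \ge \omega] \\
\log \sigma_i &= s_1 \,\mathbb{1}[x_i < \omega] + s_2 \,\mathbb{1}[x_i \ge \omega] \\
u_g &\sim \mathcal{N}(0, \tau) \\
\omega &= 100 \cdot \operatorname{logit}^{-1}(\alpha)
\end{aligned}
```

The log on $\sigma_i$ reflects the default log link for sigma in brms, whereas the simulation sets the two standard deviations directly.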
I think I got the model structure right (?) but the fitting is on the slow-ish end and I get some convergence issues. Moreover, the effective sample size is quite low although the mean effects are correctly estimated. I was wondering if you have any ideas on how to improve this implementation to speed up the code and eliminate the convergence issues? Reprex below.
Thanks!
library(tidyverse)
library(brms)
rstan::rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
set.seed(10)
b0 <- -1
b1 <- 3
b2 <- -4
s1 <- 0.5
s2 <- 1
omega <- 30
group_error <- rnorm(50, 0, 0.01)
names(group_error) <- as.character(1:50)
df <- data.frame(predictor = seq(0.2, 1e2, 0.2)) %>%
  dplyr::mutate(group_id = rep(1:50, each = 10) %>%
                  sample %>%
                  as.factor,
                mean_error = ifelse(predictor < omega,
                                    rnorm(sum(predictor < omega), 0, s1),
                                    rnorm(sum(predictor >= omega), 0, s2)),
                response = b0 + ifelse(predictor - omega < 0, b1, b2) +
                  mean_error + group_error[group_id])
ggplot(data = df, mapping = aes(y = response, x = predictor)) +
geom_point()
bform_switch <- brms::bf(response ~ b0 + b1 * step(omega - predictor) +
                           b2 * step(predictor - omega),
                         # keep omega within the range of predictor
                         brms::nlf(omega ~ inv_logit(alpha) * 1e2),
                         # allow error to change with switch
                         brms::nlf(sigma ~ s1 * step(omega - predictor) +
                                     s2 * step(predictor - omega)),
                         b0 ~ 1 + (1 | group_id),
                         s1 + s2 + b1 + b2 + alpha ~ 1,
                         nl = TRUE)
bprior <- prior(normal(0, 2), nlpar = "b0") +
  prior(normal(0, 2), nlpar = "b1") +
  prior(normal(0, 2), nlpar = "b2") +
  prior(normal(0, 1), nlpar = "alpha") +
  prior(normal(0, 1), nlpar = "s1") +
  prior(normal(0, 1), nlpar = "s2")
fit_s <- brms::brm(bform_switch, data = df, prior = bprior)
These are the warning messages I got:
1: There were 4000 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
2: Examine the pairs() plot to diagnose sampling problems
3: The largest R-hat is 1.72, indicating chains have not mixed.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#r-hat
4: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#bulk-ess
5: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#tail-ess
The outcome seems to capture the simulated values well (note that the parameters s1 and s2 are estimated on the log scale, so expect s1 $\approx$ log(0.5) and s2 $\approx$ log(1)).
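As a quick sanity check on the scales, here is a minimal base-R sketch that back-transforms the posterior means for s1, s2 and alpha (typed in by hand from the `fit_s` summary, so this is plain arithmetic on point estimates, not on posterior draws). The variable names `est`, `sd_hat` and `omega_hat` are just mine for illustration.

```r
# Posterior means copied by hand from the fit_s summary (internal model scale)
est <- c(s1 = -0.60, s2 = -0.02, alpha = -0.85)

# sigma uses a log link, so exponentiate to get residual SDs
# (simulated values: s1 = 0.5, s2 = 1)
sd_hat <- exp(est[c("s1", "s2")])

# omega ~ inv_logit(alpha) * 1e2; plogis() is base R's inverse logit
# (simulated value: omega = 30)
omega_hat <- unname(plogis(est[["alpha"]]) * 1e2)

round(sd_hat, 2)     # s1 = 0.55, s2 = 0.98
round(omega_hat, 2)  # 29.94
```

Since the predictor grid steps by 0.2, any omega between 29.8 and 30 gives the same split of the data as the simulated switch point.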
> fit_s
 Family: gaussian
  Links: mu = identity; sigma = log
Formula: response ~ b0 + b1 * step(omega - predictor) + b2 * step(predictor - omega)
         omega ~ inv_logit(alpha) * 100
         sigma ~ s1 * step(omega - predictor) + s2 * step(predictor - omega)
         b0 ~ 1 + (1 | group_id)
         s1 ~ 1
         s2 ~ 1
         b1 ~ 1
         b2 ~ 1
         alpha ~ 1
   Data: df (Number of observations: 500)
Samples: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup samples = 4000

Group-Level Effects:
~group_id (Number of levels: 50)
                 Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(b0_Intercept)     0.04      0.04     0.00     0.14 1.61        7       12

Population-Level Effects:
                Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
b0_Intercept       -1.61      0.52    -2.40    -0.37 1.29       12       25
s1_Intercept       -0.60      0.06    -0.71    -0.48 1.01      463      860
s2_Intercept       -0.02      0.04    -0.09     0.06 1.01      368      905
b1_Intercept        3.62      0.52     2.36     4.41 1.30       11       25
b2_Intercept       -3.43      0.52    -4.70    -2.65 1.30       11       24
alpha_Intercept    -0.85      0.00    -0.86    -0.85 1.00      961     2520

Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
Warning message:
Parts of the model have not converged (some Rhats are > 1.05). Be careful when analysing the results! We recommend running more iterations and/or setting stronger priors.
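Because the mean only takes two values in this model, b0 + b1 before the switch and b0 + b2 after it, those sums are what can be compared directly with the simulation. A small sketch of that comparison, using the posterior means typed in from the summary above (the helper `segment_means()` is just a name I made up for illustration):

```r
# Simulated coefficients and posterior means (copied by hand from the summary)
sim <- c(b0 = -1, b1 = 3, b2 = -4)
est <- c(b0 = -1.61, b1 = 3.62, b2 = -3.43)

# Mean response on each side of the switch point
segment_means <- function(b) {
  c(before = b[["b0"]] + b[["b1"]],
    after  = b[["b0"]] + b[["b2"]])
}

segment_means(sim)  # before = 2, after = -5
segment_means(est)  # before = 2.01, after = -5.04
```

So the two segment means are recovered closely, even though the individual coefficients (for example b0 at -1.61 versus the simulated -1) are further off, and those are also the parameters with the high Rhat values.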
Plotting the fit, including the before-after change in residual error:
plot(brms::conditional_effects(fit_s, method = "posterior_predict"),
     points = TRUE)
For the sake of transparency, this is the output of my sessionInfo():
R version 4.0.2 (2020-06-22)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS Catalina 10.15.7
Matrix products: default
BLAS: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
locale:
[1] en_AU.UTF-8/en_AU.UTF-8/en_AU.UTF-8/C/en_AU.UTF-8/en_AU.UTF-8
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] brms_2.14.4 Rcpp_1.0.5 forcats_0.5.0 stringr_1.4.0 dplyr_1.0.2 purrr_0.3.4 readr_1.4.0 tidyr_1.1.2 tibble_3.0.4 ggplot2_3.3.2 tidyverse_1.3.0
loaded via a namespace (and not attached):
[1] minqa_1.2.4 colorspace_2.0-0 ellipsis_0.3.1 ggridges_0.5.2 rsconnect_0.8.16 estimability_1.3 markdown_1.1 base64enc_0.1-3 fs_1.5.0
[10] rstudioapi_0.13 farver_2.0.3 rstan_2.21.3 DT_0.16 fansi_0.4.1 mvtnorm_1.1-1 lubridate_1.7.9.2 xml2_1.3.2 codetools_0.2-16
[19] bridgesampling_1.0-0 splines_4.0.2 shinythemes_1.1.2 bayesplot_1.7.2 projpred_2.0.2 jsonlite_1.7.1 nloptr_1.2.2.2 broom_0.7.2 dbplyr_2.0.0
[28] shiny_1.5.0 compiler_4.0.2 httr_1.4.2 emmeans_1.5.0 backports_1.2.0 assertthat_0.2.1 Matrix_1.2-18 fastmap_1.0.1 cli_2.1.0
[37] later_1.1.0.1 prettyunits_1.1.1 htmltools_0.5.0 tools_4.0.2 igraph_1.2.6 coda_0.19-4 gtable_0.3.0 glue_1.4.2 reshape2_1.4.4
[46] V8_3.4.0 cellranger_1.1.0 vctrs_0.3.5 nlme_3.1-148 crosstalk_1.1.0.1 ps_1.4.0 lme4_1.1-25 rvest_0.3.6 mime_0.9
[55] miniUI_0.1.1.1 lifecycle_0.2.0 gtools_3.8.2 statmod_1.4.35 MASS_7.3-51.6 zoo_1.8-8 scales_1.1.1 colourpicker_1.1.0 hms_0.5.3
[64] promises_1.1.1 Brobdingnag_1.2-6 parallel_4.0.2 inline_0.3.16 shinystan_2.5.0 curl_4.3 gamm4_0.2-6 gridExtra_2.3 StanHeaders_2.21.0-6
[73] loo_2.3.1 stringi_1.5.3 dygraphs_1.1.1.6 boot_1.3-25 pkgbuild_1.1.0 rlang_0.4.8 pkgconfig_2.0.3 matrixStats_0.57.0 lattice_0.20-41
[82] labeling_0.4.2 rstantools_2.1.1 htmlwidgets_1.5.2 processx_3.4.4 tidyselect_1.1.0 plyr_1.8.6 magrittr_2.0.1 R6_2.5.0 generics_0.1.0
[91] DBI_1.1.0 pillar_1.4.6 haven_2.3.1 withr_2.3.0 mgcv_1.8-31 xts_0.12.1 abind_1.4-5 modelr_0.1.8 crayon_1.3.4
[100] grid_4.0.2 readxl_1.3.1 callr_3.5.1 threejs_0.3.3 reprex_0.3.0 digest_0.6.27 xtable_1.8-4 httpuv_1.5.4 RcppParallel_5.0.2
[109] stats4_4.0.2 munsell_0.5.0 shinyjs_2.0.0