Dear forum,
I’m trying to run a change point model, which is heavily inspired by this thread.
My model differs in having multiple change points along X and an additional linear effect of the predictor X.
However, there are convergence issues: I need a large number of warmup iterations before all chains reach the region of highest posterior log density (lp__).
For the toy example below, 23,000 warmup iterations are needed, which seems atypical for Stan.
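The change points are constrained to be ordered, so that shift #2 is older (i.e. at greater X) than shift #1, and so on. This is done by summing the a parameters inside inv_logit (o2 uses a1 + a2, o3 uses a1 + a2 + a3, ...) and giving every increment after the first (a2 to a6) a strictly positive prior.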
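Written out, the ordering constraint implemented by the nlf() lines and the priors below is the following (normal given as mean and SD, exponential as rate, which is how Stan parameterizes them):

```latex
o_k = 125 \cdot \operatorname{logit}^{-1}\!\Big(\sum_{j=1}^{k} a_j\Big), \qquad k = 1, \dots, 6,
\qquad a_1 \sim \mathcal{N}(-1,\, 2), \quad a_j \sim \operatorname{Exponential}(2),\ a_j \ge 0 \ (j \ge 2).

\text{With } s_{k-1} = \sum_{j=1}^{k-1} a_j:\quad
o_k - o_{k-1} = 125 \left[\operatorname{logit}^{-1}(s_{k-1} + a_k) - \operatorname{logit}^{-1}(s_{k-1})\right] \ge 0,

\text{since } \operatorname{logit}^{-1} \text{ is strictly increasing and } a_k \ge 0,
\text{ so } 0 < o_1 \le o_2 \le \dots \le o_6 < 125.
```

Because each increment is only bounded below by 0, two neighbouring change points coincide when the corresponding a_k is close to 0.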
```r
# The data
X <- 1:125
Y <- c(rnorm(5, 0, 0.3),      # Period 1
       rnorm(5, 1.5, 0.3),    # Period 2
       rnorm(10, 0, 0.3),     # Period 3
       rnorm(5, 1.0, 0.3),    # Period 4
       rnorm(25, -1.2, 0.3),  # Period 5
       rnorm(50, 0.1, 0.3),   # Period 6
       rnorm(25, 1.1, 0.3))   # Period 7
Y <- Y - 0.01 * X             # add a linear trend along X
XY <- data.frame(X = X, Y = Y)
```
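Because the data are simulated, the shift locations the model should recover are known. A small base-R sketch that derives them from the period lengths above and inverts the model's change-point transform to get the matching a1 to a6. Placing each shift halfway between the last X of one period and the first X of the next is an assumption made for illustration; the variable names are my own.

```r
# Lengths of the seven simulated periods (same order as the rnorm() calls)
period_len <- c(5, 5, 10, 5, 25, 50, 25)
stopifnot(sum(period_len) == 125)

# Last X of each of the first six periods: 5, 10, 20, 25, 50, 100
last_x <- cumsum(period_len)[1:6]

# Assumption: put each true shift halfway to the first X of the next period
o_true <- last_x + 0.5

# Invert o_k = inv_logit(a1 + ... + ak) * 125:
# qlogis() is base R's logit; diff() recovers the per-shift increments a_k
a_true <- diff(c(0, qlogis(o_true / 125)))
round(a_true, 2)
```

These values can be compared with the posterior of the a parameters, or used to judge whether the priors put reasonable mass near them.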
```r
# The model
bform_switch <- brms::bf(
  Y ~ bX * X +
    b1 +                                                       # Baseline level, i.e. younger than Shift 1 (b2..b7 are offsets from it)
    b2 * inv_logit((X - o1) * 5) * inv_logit((o2 - X) * 5) +   # Between Shift 1 and 2
    b3 * inv_logit((X - o2) * 5) * inv_logit((o3 - X) * 5) +   # Between Shift 2 and 3
    b4 * inv_logit((X - o3) * 5) * inv_logit((o4 - X) * 5) +   # Between Shift 3 and 4
    b5 * inv_logit((X - o4) * 5) * inv_logit((o5 - X) * 5) +   # Between Shift 4 and 5
    b6 * inv_logit((X - o5) * 5) * inv_logit((o6 - X) * 5) +   # Between Shift 5 and 6
    b7 * inv_logit((X - o6) * 5),                              # Older than Shift 6
  # Keep the change points o1..o6 within the range of the predictor (0 to 125)
  brms::nlf(o1 ~ inv_logit(a1) * 125),
  brms::nlf(o2 ~ inv_logit(a1 + a2) * 125),
  brms::nlf(o3 ~ inv_logit(a1 + a2 + a3) * 125),
  brms::nlf(o4 ~ inv_logit(a1 + a2 + a3 + a4) * 125),
  brms::nlf(o5 ~ inv_logit(a1 + a2 + a3 + a4 + a5) * 125),
  brms::nlf(o6 ~ inv_logit(a1 + a2 + a3 + a4 + a5 + a6) * 125),
  bX + b1 + b2 + b3 + b4 + b5 + b6 + b7 + a1 + a2 + a3 + a4 + a5 + a6 ~ 1,
  nl = TRUE)
```
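To see what the nonlinear formula computes, here is a base-R mirror of its mean function. It is a sketch for intuition, not brms code: plogis() is base R's inv_logit, and the function and argument names are my own. Each b2 to b6 term is switched on by the product of a rising and a falling logistic step, so the level inside a window is b1 plus that window's b, and the factor 5 sets how sharp the steps are.

```r
# Mean of Y under the change-point formula (no noise).
# b: the 7 level parameters b1..b7; o: the 6 ordered change points o1..o6;
# k: steepness of the logistic steps (5 in the brms formula).
change_point_mean <- function(X, bX, b, o, k = 5) {
  stopifnot(length(b) == 7, length(o) == 6)
  mu <- bX * X + b[1]                        # linear trend plus baseline level
  for (j in 1:5) {                           # window between o[j] and o[j + 1] uses b[j + 1]
    mu <- mu + b[j + 1] * plogis((X - o[j]) * k) * plogis((o[j + 1] - X) * k)
  }
  mu + b[7] * plogis((X - o[6]) * k)         # older than the last shift
}

# Evaluate at the simulated values (b1 = 0, so the offsets equal the period means)
X <- 1:125
mu <- change_point_mean(X, bX = -0.01,
                        b = c(0, 1.5, 0, 1.0, -1.2, 0.1, 1.1),
                        o = c(5.5, 10.5, 20.5, 25.5, 50.5, 100.5))
```

Plotting mu against X shows the smoothed staircase the sampler is trying to fit; away from the change points the logistic steps are essentially 0 or 1.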
```r
# Priors
library(brms)  # prior() is called without the brms:: prefix below
bprior <- c(prior(normal(0, 1), nlpar = "bX"),
            prior(normal(0, 2), nlpar = "b1"),
            prior(normal(0, 2), nlpar = "b2"),
            prior(normal(0, 2), nlpar = "b3"),
            prior(normal(0, 2), nlpar = "b4"),
            prior(normal(0, 2), nlpar = "b5"),
            prior(normal(0, 2), nlpar = "b6"),
            prior(normal(0, 2), nlpar = "b7"),
            prior(normal(-1, 2), nlpar = "a1"),
            prior(exponential(2), nlpar = "a2", lb = 0),
            prior(exponential(2), nlpar = "a3", lb = 0),
            prior(exponential(2), nlpar = "a4", lb = 0),
            prior(exponential(2), nlpar = "a5", lb = 0),
            prior(exponential(2), nlpar = "a6", lb = 0))
```
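Because the change points enter through inv_logit of cumulative sums, it is not obvious from the prior statements alone where these priors place o1 to o6. A quick prior simulation in base R makes it visible. It assumes, as in Stan, that normal(-1, 2) has SD 2 and exponential(2) has rate 2; the variable names are my own.

```r
set.seed(1)
n_draws <- 4000
a1 <- rnorm(n_draws, mean = -1, sd = 2)                   # prior for a1
a_rest <- matrix(rexp(n_draws * 5, rate = 2), ncol = 5)   # priors for a2..a6 (>= 0)
a <- cbind(a1, a_rest)                                    # n_draws x 6

s <- t(apply(a, 1, cumsum))      # a1, a1 + a2, ..., a1 + ... + a6 per draw
o <- 125 * plogis(s)             # change points o1..o6 per draw
dim(o) <- dim(s)                 # keep the draws x 6 shape
colnames(o) <- paste0("o", 1:6)

# Prior 5%, 50% and 95% quantiles of each change point
apply(o, 2, quantile, probs = c(0.05, 0.5, 0.95))
```

Comparing these quantiles with the shifts in the simulated data (after X = 5, 10, 20, 25, 50 and 100) shows how much of the prior mass lies near the values the chains have to find.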
```r
# Sampling
fit_s <- brms::brm(bform_switch, data = XY, prior = bprior,
                   cores = 3, chains = 3,
                   iter = 2000, warmup = 1000, thin = 1, refresh = 500,
                   seed = 13,
                   backend = "cmdstanr",
                   control = list(adapt_delta = 0.999, max_treedepth = 20))
```
With the default warmup and iterations, the trace of lp__ looks like this:
```r
bayesplot::mcmc_trace(fit_s, pars = "lp__")
```
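Eyeballing the lp__ traces can be complemented by a number: the split R-hat, which compares between-chain and within-chain variance and rises well above 1 when chains sit at different lp__ levels. Below is a minimal base-R sketch of the classic (not rank-normalized) version, assuming the lp__ draws are available as an iterations × chains matrix (for example from `as.array()` on the fit; check the dimension order in your brms version). The posterior package's `rhat()` implements the rank-normalized variant, which is preferable in practice.

```r
# Classic split-R-hat for one quantity (e.g. lp__).
# draws: numeric matrix, rows = post-warmup iterations, columns = chains.
split_rhat <- function(draws) {
  draws <- as.matrix(draws)
  n <- floor(nrow(draws) / 2)
  if (n < 2) stop("need at least 4 iterations per chain")
  # First and last n iterations of every chain become separate half-chains
  # (the middle draw is dropped when the number of iterations is odd)
  halves <- cbind(draws[seq_len(n), , drop = FALSE],
                  draws[nrow(draws) - n + seq_len(n), , drop = FALSE])
  W <- mean(apply(halves, 2, var))       # mean within-chain variance
  B <- n * var(colMeans(halves))         # between-chain variance
  var_plus <- (n - 1) / n * W + B / n    # pooled variance estimate
  sqrt(var_plus / W)
}

# Toy check with made-up draws (not the lp__ from this model):
set.seed(42)
mixed <- matrix(rnorm(3000), ncol = 3)   # 3 chains, 1000 iterations, same target
stuck <- mixed
stuck[, 3] <- stuck[, 3] + 5             # third chain sits at a different level
split_rhat(mixed)                        # close to 1
split_rhat(stuck)                        # far above 1
```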
With warmup = 23000 and iter = 25000, the trace of lp__ looks better:
The conditional effects plot looks as expected:
```r
plot(brms::conditional_effects(fit_s), points = TRUE)
```
What I have tried so far to improve convergence:
- Increasing adapt_delta and max_treedepth: there are no divergence (or other) warnings, and fewer iterations are needed, but not by much.
- Providing initial values: initial values closer to the known shift points help, because the chains start in a more probable part of the parameter space. But this feels like cheating.
- Tighter priors: do not help much.
- Making the shifts and the linear predictor "independent": I scaled X
  ```r
  XY <- data.frame(X = X, Xscale = scale(X), Y = Y)
  ```
  and specified `bX * Xscale` in my model `bform_switch`.
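Scaling X changes what bX (and the baseline b1) mean, so posterior values have to be back-transformed before comparing them with the simulated slope of -0.01. A minimal base-R sketch; the helper `to_raw_scale()` is hypothetical, and `as.numeric()` only drops the matrix attributes that `scale()` returns.

```r
X <- 1:125
Xscale <- as.numeric(scale(X))   # (X - mean(X)) / sd(X)
stopifnot(abs(mean(Xscale)) < 1e-12, abs(sd(Xscale) - 1) < 1e-12)

# Hypothetical helper: map slope and baseline on the scaled predictor back to raw X,
# using bX * (X - m) / s + b1 == (bX / s) * X + (b1 - bX * m / s)
to_raw_scale <- function(bX_scaled, b1_scaled, x) {
  m <- mean(x)
  s <- sd(x)
  c(bX = bX_scaled / s, b1 = b1_scaled - bX_scaled * m / s)
}

# A raw slope of -0.01 per unit X corresponds to -0.01 * sd(X) per SD of X
r <- to_raw_scale(bX_scaled = -0.01 * sd(X), b1_scaled = 0, x = X)
r
```

With the scaled predictor, b1 is the baseline level at the mean of X rather than at X = 0, which is why it shifts as well.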
Thanks for any advice on how to improve convergence!
Operating System: Manjaro 21
Interface Version: brms 2.16.1
Compiler/Toolkit: cmdstan 2.27