Long warm-up for change point model

Dear forum,

I’m trying to run a change point model, which is heavily inspired by this thread.
My model differs by having multiple change points along X and an additional linear relationship with predictor X.

However, there are convergence issues, and I need a large number of warmup iterations before all chains sample the model parameters from the region of highest log posterior density (lp__).
For the toy example below about 23,000 warmup iterations are needed, which seems atypical for Stan.

The change points in the model are constrained so that shift #2 is older (i.e. at greater X) than shift #1: the a parameters are summed cumulatively, and the increment for change point #2 gets a strictly positive prior. The same construction is used for shifts #3 to #6.
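The ordering constraint can be checked outside of Stan. Below is a minimal base-R sketch, assuming arbitrary illustrative values for a1 to a6 (they are not fitted estimates). Because a2 to a6 are positive and inv_logit is monotone, the cumulative sums should map to increasing change points inside (0, 125).

```r
# Illustrative values only: a1 is unconstrained, a2..a6 are positive increments
a <- c(-3, 0.7, 0.8, 0.3, 1.0, 1.8)

# Same mapping as the nlf() lines in the model:
# o_k = inv_logit(a_1 + ... + a_k) * 125
o <- plogis(cumsum(a)) * 125

print(round(o, 1))

# The change points are ordered and stay inside the range of X
stopifnot(all(diff(o) > 0), all(o > 0 & o < 125))
```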

library(cmdstanr)
library(brms)
library(bayesplot)

# The data
set.seed(73)
X <- c(1:125)
Y <- c(rnorm(5, 0, 0.3), # Period 1
       rnorm(5, 1.5, 0.3), # Period 2
       rnorm(10, 0, 0.3), # Period 3
       rnorm(5, 1.0, 0.3), # Period 4 
       rnorm(25, -1.2, 0.3), # Period 5
       rnorm(50, 0.1, 0.3), # Period 6
       rnorm(25, 1.1, 0.3)) # Period 7
Y <- Y - 0.01 * X
XY <- data.frame(X = X, Y = Y)

# The model
bform_switch <- brms::bf(Y ~ bX * X +
                           b1 + # Younger than Shift1
                           b2 * inv_logit((X - o1) * 5) * inv_logit((o2 - X) * 5) + # Between Shift 1 and 2
                           b3 * inv_logit((X - o2) * 5) * inv_logit((o3 - X) * 5) + # Between Shift 2 and 3
                           b4 * inv_logit((X - o3) * 5) * inv_logit((o4 - X) * 5) + # Between Shift 3 and 4
                           b5 * inv_logit((X - o4) * 5) * inv_logit((o5 - X) * 5) + # Between Shift 4 and 5
                           b6 * inv_logit((X - o5) * 5) * inv_logit((o6 - X) * 5) + # Between Shift 5 and 6
                           b7 * inv_logit((X - o6) * 5), # Older than Shift6
                         # keep omega within the range of predictor
                         brms::nlf(o1 ~ inv_logit(a1) * 125),
                         brms::nlf(o2 ~ inv_logit(a1 + a2) * 125),
                         brms::nlf(o3 ~ inv_logit(a1 + a2 + a3) * 125),
                         brms::nlf(o4 ~ inv_logit(a1 + a2 + a3 + a4) * 125),
                         brms::nlf(o5 ~ inv_logit(a1 + a2 + a3 + a4 + a5) * 125),
                         brms::nlf(o6 ~ inv_logit(a1 + a2 + a3 + a4 + a5 + a6) * 125),
                         bX + b1 + b2 + b3 + b4 + b5 + b6 + b7 + a1 + a2 + a3 + a4 + a5 + a6 ~ 1,
                         nl = TRUE)

# Priors
bprior <- c(prior(normal(0, 1), nlpar = "bX"),
            prior(normal(0, 2), nlpar = "b1"),
            prior(normal(0, 2), nlpar = "b2"),
            prior(normal(0, 2), nlpar = "b3"),
            prior(normal(0, 2), nlpar = "b4"),
            prior(normal(0, 2), nlpar = "b5"),
            prior(normal(0, 2), nlpar = "b6"),
            prior(normal(0, 2), nlpar = "b7"),
            prior(normal(-1, 2), nlpar = "a1"),
            prior(exponential(2), nlpar = "a2", lb = 0),
            prior(exponential(2), nlpar = "a3", lb = 0),
            prior(exponential(2), nlpar = "a4", lb = 0),
            prior(exponential(2), nlpar = "a5", lb = 0),
            prior(exponential(2), nlpar = "a6", lb = 0))

# Sampling
fit_s <- brms::brm(bform_switch, data = XY, prior = bprior,
                   cores = 3, chains = 3,
                   iter = 2000, warmup = 1000, thin = 1, refresh = 500,
                   seed = 13,
                   backend = "cmdstanr",
                   control = list(adapt_delta = 0.999, max_treedepth = 20))
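
To illustrate the parameterisation: each product of two inv_logit terms in the formula above acts as a smooth on/off window between two neighbouring change points. A small base-R sketch, with o1 and o2 fixed to hypothetical values rather than estimated:

```r
# Hypothetical change points, chosen only for illustration
o1 <- 5.5
o2 <- 10.5
X <- 1:20

# Same form as the "Between Shift 1 and 2" term in the formula (steepness 5):
# close to 1 for o1 < X < o2 and close to 0 elsewhere
w <- plogis((X - o1) * 5) * plogis((o2 - X) * 5)

print(round(w, 3))
```

With a steepness of 5, the weight is about 0.08 half a unit outside a change point and about 0.92 half a unit inside it, so on this integer grid the window is almost a hard step.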

With default warmup and iterations, trace of lp__ looks like this:

mcmc_trace(fit_s, "lp__")

[Figure: trace plot of lp__, 2,000 iterations (Lp_trace_2000)]

With 23,000 warmup iterations and 25,000 iterations, the trace of lp__ looks better:

[Figure: trace plot of lp__, 25,000 iterations (Lp_trace25000)]

Conditional effect plot is as expected:

plot(brms::conditional_effects(fit_s), points = TRUE)

[Figure: conditional effects plot with data points (CondEff)]

What I have tried so far to improve convergence:

  • Increasing adapt_delta and max_treedepth: There are no warnings about divergences etc., and fewer iterations are needed, but not by much.
  • Providing initial values via inits: Initial values closer to the known shift points help because the chains start in a more likely part of the parameter space. But this feels like cheating.
  • Tighter priors: Does not help much.
  • Making shifts and linear predictor “independent”: I scaled X with XY <- data.frame(X = X, Xscale = scale(X), Y = Y) and specified bX * Xscale in my model bform_switch.
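
For reference, the initial values mentioned above can be derived by inverting the change point mapping. This is only a sketch: the change points are the known period boundaries of the simulated toy data, and the Stan parameter names b_a1 to b_a6 are an assumption about what brms generates, which should be checked against brms::make_stancode().

```r
# Known change points of the toy data (midpoints between simulated periods)
o_known <- c(5.5, 10.5, 20.5, 25.5, 50.5, 100.5)

# Invert o_k = inv_logit(a_1 + ... + a_k) * 125 to get values on the a scale
s <- qlogis(o_known / 125)   # cumulative sums a_1 + ... + a_k
a_init <- c(s[1], diff(s))   # a_1, then the positive increments a_2..a_6

# Round trip: the a values reproduce the change points
stopifnot(isTRUE(all.equal(plogis(cumsum(a_init)) * 125, o_known)))

# Assumption: each nonlinear parameter is a length-1 vector named b_<nlpar>
make_init <- function() {
  setNames(lapply(a_init, function(v) array(v, dim = 1)), paste0("b_a", 1:6))
}

# One list per chain (3 chains, as in the brm() call)
init_list <- replicate(3, make_init(), simplify = FALSE)
str(init_list[[1]])
```

A list like init_list would then be passed to brms::brm() through its initial-values argument (inits in brms 2.16.1).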

Thanks for any advice on how to improve convergence!

Operating System: Manjaro 21
Interface Version: brms 2.16.1
Compiler/Toolkit: cmdstan 2.27