Long warm-up for change point model

Dear forum,

I’m trying to run a change point model, which is heavily inspired by this thread.
My model differs by having multiple change points along X and an additional linear relationship with predictor X.

However, there are convergence issues, and I need a large number of warmup iterations for all chains to sample from the highest-log-density region of the posterior (lp__).
For the toy example below, 23,000 warmup iterations are needed, which seems atypical for Stan.

The change points in the model are constrained so that shift #2 is older (i.e. has greater X) than shift #1: the a parameters are summed cumulatively, and the increment for change point #2 gets a strictly positive prior.
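To illustrate just the ordering constraint outside brms, here is a minimal base-R sketch: because inv_logit is strictly increasing, adding a positive a2 to the cumulative sum guarantees o2 > o1, and both change points stay within (0, 125).

```r
# Minimal sketch of the ordering constraint used in the model:
# inv_logit is strictly increasing, so a positive increment a2
# guarantees o2 > o1 while keeping both within the predictor range.
inv_logit <- function(x) 1 / (1 + exp(-x))

a1 <- -1.0   # unconstrained
a2 <-  0.5   # strictly positive (matches the exponential prior)

o1 <- inv_logit(a1) * 125
o2 <- inv_logit(a1 + a2) * 125

stopifnot(o2 > o1)    # ordering holds for any a2 > 0
c(o1 = o1, o2 = o2)   # both lie in (0, 125)
```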

library(cmdstanr)
library(brms)
library(bayesplot)

# The data
set.seed(73)
X <- c(1:125)
Y <- c(rnorm(5, 0, 0.3), # Period 1
       rnorm(5, 1.5, 0.3), # Period 2
       rnorm(10, 0, 0.3), # Period 3
       rnorm(5, 1.0, 0.3), # Period 4 
       rnorm(25, -1.2, 0.3), # Period 5
       rnorm(50, 0.1, 0.3), # Period 6
       rnorm(25, 1.1, 0.3)) # Period 7
Y <- Y - 0.01 * X
XY <- data.frame(X = X, Y = Y)

# The model
bform_switch <- brms::bf(Y ~ bX * X +
                           b1 + # Younger than Shift1
                           b2 * inv_logit((X - o1) * 5) * inv_logit((o2 - X) * 5) + # Between Shift 1 and 2
                           b3 * inv_logit((X - o2) * 5) * inv_logit((o3 - X) * 5) + # Between Shift 2 and 3
                           b4 * inv_logit((X - o3) * 5) * inv_logit((o4 - X) * 5) + # Between Shift 3 and 4
                           b5 * inv_logit((X - o4) * 5) * inv_logit((o5 - X) * 5) + # Between Shift 4 and 5
                           b6 * inv_logit((X - o5) * 5) * inv_logit((o6 - X) * 5) + # Between Shift 5 and 6
                           b7 * inv_logit((X - o6) * 5), # Older than Shift6
                         # keep omega within the range of predictor
                         brms::nlf(o1 ~ inv_logit(a1) * 125),
                         brms::nlf(o2 ~ inv_logit(a1 + a2) * 125),
                         brms::nlf(o3 ~ inv_logit(a1 + a2 + a3) * 125),
                         brms::nlf(o4 ~ inv_logit(a1 + a2 + a3 + a4) * 125),
                         brms::nlf(o5 ~ inv_logit(a1 + a2 + a3 + a4 + a5) * 125),
                         brms::nlf(o6 ~ inv_logit(a1 + a2 + a3 + a4 + a5 + a6) * 125),
                         bX + b1 + b2 + b3 + b4 + b5 + b6 + b7 + a1 + a2 + a3 + a4 + a5 + a6 ~ 1,
                         nl = TRUE)

# Priors
bprior <- c(prior(normal(0, 1), nlpar = "bX"),
            prior(normal(0, 2), nlpar = "b1"),
            prior(normal(0, 2), nlpar = "b2"),
            prior(normal(0, 2), nlpar = "b3"),
            prior(normal(0, 2), nlpar = "b4"),
            prior(normal(0, 2), nlpar = "b5"),
            prior(normal(0, 2), nlpar = "b6"),
            prior(normal(0, 2), nlpar = "b7"),
            prior(normal(-1, 2), nlpar = "a1"),
            prior(exponential(2), nlpar = "a2", lb = 0),
            prior(exponential(2), nlpar = "a3", lb = 0),
            prior(exponential(2), nlpar = "a4", lb = 0),
            prior(exponential(2), nlpar = "a5", lb = 0),
            prior(exponential(2), nlpar = "a6", lb = 0))

# Sampling
fit_s <- brms::brm(bform_switch, data = XY, prior = bprior,
                   cores = 3, chains = 3,
                   iter = 2000, warmup = 1000, thin = 1, refresh = 500,
                   seed = 13,
                   backend = "cmdstanr",
                   control = list(adapt_delta = 0.999, max_treedepth = 20))

With the default warmup and iterations, the trace of lp__ looks like this:

mcmc_trace(fit_s, "lp__")

[Figure Lp_trace_2000: trace of lp__ with default settings; chains settle at different lp__ levels]

With 23,000 warmup and 25,000 total iterations, the trace of lp__ looks better:

[Figure Lp_trace25000: trace of lp__ with extended warmup; chains mix at the same lp__ level]

Conditional effect plot is as expected:

plot(brms::conditional_effects(fit_s), points = TRUE)

[Figure CondEff: conditional effects plot]

What I have tried so far to improve convergence:

  • Increase adapt_delta and max_treedepth: although there are no divergence warnings, this reduces the number of iterations needed, but not by much.
  • Provide initial values via inits: initial values closer to the known shift points help, because the chains start in a more likely part of the parameter space. But this feels like cheating.
  • Tighter priors: this does not help much.
  • Making the shifts and the linear predictor “independent”: I scaled X (XY <- data.frame(X = X, Xscale = scale(X), Y = Y)) and specified bX * Xscale in bform_switch.

Thanks for any advice on how to improve convergence!

Operating System: Manjaro 21
Interface Version: brms 2.16.1
Compiler/Toolkit: cmdstan 2.27


Hi,
it appears the posterior for your model + data has multiple separate modes. My best guess is that they arise when some of the change points are squashed very close together or pushed entirely out of the range of the data: the model then finds the best fit with fewer change points, and any small perturbation from this configuration results in a locally worse fit. You should be able to confirm this hypothesis by looking at how the distribution of change point locations differs between chains in runs that don’t converge.
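One way to do that per-chain comparison is to look at the a-parameters, which determine the change points. This is only a sketch: it assumes brms names the intercepts b_a1_Intercept, b_a2_Intercept, etc. for the fit above, which you should verify with variables(fit_s) before relying on the regex or column names.

```r
# Sketch (assumes the b_<nlpar>_Intercept naming convention of brms):
# inspect the a-parameters separately per chain; clearly different
# per-chain distributions indicate chains stuck in different modes.
bayesplot::mcmc_trace(fit_s, regex_pars = "^b_a")

# Or summarise one parameter per chain with the posterior package:
library(posterior)
draws <- as_draws_df(fit_s)
aggregate(b_a1_Intercept ~ .chain, data = draws, FUN = median)
```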

However, it also appears that, given enough iterations, the chains are able to transition between the modes and eventually settle in the highest-lp mode.

You may be able to partially resolve the issue by providing better initial values for the parameters (e.g. change points equidistant across the range of the data), but I don’t think that will work 100% of the time: posteriors with multiple modes are simply challenging for any inference algorithm, including Stan.
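For equidistant starting points, the corresponding a-scale values can be recovered by inverting the transform used in the model (a base-R sketch; qlogis is the inverse of inv_logit):

```r
# Back-transform equidistant change points to the model's a-scale:
# o_k = inv_logit(a1 + ... + a_k) * 125, so the a-values are the first
# differences of qlogis(o_k / 125).
o_init <- 125 * (1:6) / 7          # six change points, evenly spaced in (0, 125)
s      <- qlogis(o_init / 125)     # cumulative sums a1 + ... + a_k
a_init <- c(s[1], diff(s))         # a1 unconstrained, a2..a6 strictly positive

stopifnot(all(a_init[-1] > 0))     # consistent with the positive priors on a2..a6
round(a_init, 3)
```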

Best of luck with your model!


Thanks a lot for your thoughts and explanations!

I can confirm that, in the case of non-convergence, the chains differ in their inferred change points.
Moreover, setting initial values closer to the known change points aids convergence.

Another trick I have found is to scale X before using it as the linear predictor.

XY <- data.frame(X = X, Xscale = scale(X), Y = Y)

# The model
bform_switch <- brms::bf(Y ~ bX * Xscale +
                           b1 + # Younger than Shift1
                           b2 * inv_logit((X - o1) * 5) * inv_logit((o2 - X) * 5) + # Between Shift 1 and 2
                           b3 * inv_logit((X - o2) * 5) * inv_logit((o3 - X) * 5) + # Between Shift 2 and 3
                           b4 * inv_logit((X - o3) * 5) * inv_logit((o4 - X) * 5) + # Between Shift 3 and 4
                           b5 * inv_logit((X - o4) * 5) * inv_logit((o5 - X) * 5) + # Between Shift 4 and 5
                           b6 * inv_logit((X - o5) * 5) * inv_logit((o6 - X) * 5) + # Between Shift 5 and 6
                           b7 * inv_logit((X - o6) * 5), # Older than Shift6
                         # keep omega within the range of predictor
                         brms::nlf(o1 ~ inv_logit(a1) * 125),
                         brms::nlf(o2 ~ inv_logit(a1 + a2) * 125),
                         brms::nlf(o3 ~ inv_logit(a1 + a2 + a3) * 125),
                         brms::nlf(o4 ~ inv_logit(a1 + a2 + a3 + a4) * 125),
                         brms::nlf(o5 ~ inv_logit(a1 + a2 + a3 + a4 + a5) * 125),
                         brms::nlf(o6 ~ inv_logit(a1 + a2 + a3 + a4 + a5 + a6) * 125),
                         bX + b1 + b2 + b3 + b4 + b5 + b6 + b7 + a1 + a2 + a3 + a4 + a5 + a6 ~ 1,
                         nl = TRUE)