Initialize MCMC with draws from variational inference in cmdstanr

I’m using cmdstanr and I want to initialize 4 chains of my MCMC with draws (or perhaps the means) from my variational approximation. There is a nice page on how to do this with CmdStanPy, but surprisingly I can’t find anything about it for cmdstanr. On top of that, converting draws to the appropriate init list format is a big pain to do manually, especially with vectors. This feels like a common use case, so I’m surprised not to find anything on it. Any help?

One of my case studies, Bayesian workflow book - Birthdays, shows how to do that with values from optimization (see the sketch below), and it should be relatively easy to do the same with draws from a variational approximation.
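A minimal sketch of that approach in cmdstanr (the model object mod, the data list, and the parameter names alpha and beta are placeholders, not taken from the case study):

library(cmdstanr)

# Find the (penalized) MLE / posterior mode
opt <- mod$optimize(data = data_list)
mle <- opt$mle()  # named vector: alpha, beta[1], beta[2], ...

# Gather the elements of each parameter into a named list; the regex
# matches "beta" and "beta[...]" but not other names sharing the prefix
init_from_opt <- sapply(c("alpha", "beta"), function(v) {
  unname(mle[grep(paste0("^", v, "(\\[|$)"), names(mle))])
}, simplify = FALSE)

# Same starting point for all 4 chains
fit <- mod$sample(data = data_list, chains = 4,
                  init = rep(list(init_from_opt), 4))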


Thanks, this is exactly what I needed! And wow, what a difference the initialization made. For reference, I was implementing a seemingly unrelated regressions (SUR) model exactly as outlined in the reference manual. Typically 1 or 2 chains finished quickly (minutes) and 2 got stuck at suboptimal modes (hours to run, much lower log posterior). Here I initialize each chain with a random draw from VI, and all chains finish within minutes, with great Rhats and n_eff/N.

library(posterior)  # subset() method for draws objects
library(dplyr)

draws_df <- vi_res$draws()

# One named init list per chain: each parameter gets a random VI draw.
# simplify = FALSE keeps the result a named list, so vector parameters
# stay vectors instead of being collapsed into a matrix.
init1 <- lapply(1:4, function(chain) {
  sapply(c('alpha', 'beta', 'L_sigma', 'L_Omega'), function(variable) {
    as.numeric(subset(draws_df, variable = variable) %>%
      as_tibble() %>%
      sample_n(size = 1))
  }, simplify = FALSE)
})

stan_results <- sm_full$sample(
  data = data_list, chains = 4,
  init = init1,
  max_treedepth = 12, adapt_delta = 0.95,
  refresh = 10
)
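One subtlety in the snippet above: sample_n() runs once per variable, so a chain’s init mixes values from different VI draws. A small variant (same assumed variable names) that takes one whole draw per chain keeps the parameters jointly consistent:

init_joint <- lapply(1:4, function(chain) {
  row <- sample(nrow(draws_df), 1)  # one VI draw for the whole chain
  sapply(c('alpha', 'beta', 'L_sigma', 'L_Omega'), function(variable) {
    as.numeric(subset(draws_df, variable = variable)[row, ])
  }, simplify = FALSE)
})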

It’s amazing how much of a difference this makes. It seems like there should be a way to warn users at runtime about chains that get stuck (low log posterior, small step size). I’m often running into similar issues where some chains converge quickly and others don’t. I guess best practice is to come up with better (more informative) priors first?


It’s difficult to do by default, as low log density and small step size are relative quantities. Maybe GitHub - flatironinstitute/mcmc-monitor: Monitor MCMC runs in the browser would be helpful for you.

More informative priors don’t change the way initialization is done, and if the initial value is somewhere where floating-point accuracy is not sufficient, HMC may behave badly. Of course, if you can come up with better priors, you could also come up with better initial values. Pathfinder: Parallel quasi-Newton variational inference will soon be available in Stan, and will also make it easier to get better initial values.
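Once Pathfinder lands in cmdstanr, using it for inits might look something like the sketch below; the $pathfinder() method name, and passing its fit object directly as init, are assumptions about the eventual interface:

pf <- sm_full$pathfinder(data = data_list)

stan_results <- sm_full$sample(
  data = data_list, chains = 4,
  init = pf  # or resample rows of pf$draws(), as in the VI snippet above
)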

I’ve noticed too that a decent initialization takes a little extra work but can make a big difference, especially in models with ODEs. I’ve also seen cases where initializing the mass matrix gives a big speedup in the early stages of warmup: in the first stage of warmup the mass matrix is fixed to the identity, so if there’s a parameter whose marginal posterior SD is small, say 1e-3, the leapfrog step size has to shrink to compensate and you get large treedepths for the first 75 or so warmup iterations.
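For example, one could seed the diagonal inverse metric with variances estimated from the VI draws. This is only a rough sketch: the inverse metric lives on the unconstrained scale, so the VI variances are only approximately right for constrained parameters, and the vector must follow the model’s declaration order:

library(posterior)

# drop the non-parameter columns that VI draws carry
par_draws <- subset(draws_df, variable = c("lp__", "lp_approx__"),
                    exclude = TRUE)

stan_results <- sm_full$sample(
  data = data_list, chains = 4,
  metric = "diag_e",
  inv_metric = apply(par_draws, 2, var),  # one variance per parameter
  init = init1  # e.g. the VI-based inits from earlier in the thread
)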

If I have a big model that I’ve fit previously and want to refit now that I’ve got additional data, I’ll try to initialize the parameters and the mass matrix using the old fit.
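In cmdstanr that can look roughly like this, assuming old_fit is a CmdStanMCMC fit of the same model (the adapted inverse metric is available via $inv_metric()):

# adapted diagonal inverse metric from chain 1 of the previous run
old_inv_metric <- old_fit$inv_metric(matrix = FALSE)[[1]]

fit_new <- sm_full$sample(
  data = new_data_list, chains = 4,
  inv_metric = old_inv_metric,
  init = init1  # e.g. draws resampled from old_fit$draws(), as above
)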

Lastly, I’ve found that for certain complicated models the parameters found by VI can be pretty far from what I get with NUTS, especially if there’s any sort of bimodality in the posterior.
