Cross-chain warmup adaptation using MPI

Here are R scripts for running multi-chain adaptive warmup and classic warmup, and for computing the desired diagnostics.

We want to see if the adaptive warmup:

  • reduces n_warmup from the default (adaptation is able to stop early)
  • reduces sum_warmup_leapfrogs (total adaptation computational cost is less)
  • reduces mean_warmup_leapfrogs (better mass-matrix adaptation)
  • reduces sum_leapfrogs and mean_leapfrogs (better mass-matrix adaptation)
  • increases bulk_ess_per_iter and tail_ess_per_iter (better mass-matrix adaptation)
  • increases bulk_ess_per_leapfrog and tail_ess_per_leapfrog (better mass-matrix adaptation)

Note! Right now we don’t care about wall clock time, and reporting anything per wall clock second is likely to be a distraction at this point. Running with a different number of chains (>=4) is fine.

# Base name of the Stan program; assumes there exists file normal.stan in the
# current working directory.
modelname <- "normal"
# Data list handed to the model in both runs below.
data <- list(D = 32)

# multi-chain adaptive warmup
# Compiles the model against the MPI-enabled CmdStan install, runs it through
# mpiexec, reads the per-process CSV files back with rstan and prints the
# warmup / sampling cost and efficiency diagnostics.
n_chains <- 4 # number of MPI processes; one mpi.<k>.output.csv is read for each
cmdstanr::set_cmdstan_path("~/.cmdstanr/cmdstanmpi")
mpimodel <- cmdstanr::cmdstan_model(paste0(modelname, ".stan"), quiet = FALSE)
datapath <- cmdstanr:::process_data(data)
mpi_cmd <- paste0(
  "mpiexec -n ", n_chains, " --tag-output ./", modelname,
  " sample save_warmup=1 data file=", datapath
)
# Stop on a failed run so that CSV files left over from an earlier run are
# not silently read below.
mpi_status <- system(mpi_cmd)
if (mpi_status != 0) {
  stop("mpiexec exited with status ", mpi_status, call. = FALSE)
}
stanfit <- rstan::read_stan_csv(
  paste0("mpi.", seq_len(n_chains) - 1, ".output.csv")
)
(n_warmup <- stanfit@sim$warmup)
n_iter <- stanfit@sim$iter - n_warmup
# Sampler diagnostics with the warmup rows included, one list element per chain.
sampler_params <- rstan:::get_sampler_params(stanfit, inc_warmup = TRUE)
leapfrogs <- sapply(sampler_params, function(x) x[, "n_leapfrog__"])
# seq_len() keeps the index empty (instead of c(1, 0)) if a count is zero.
warmup_rows <- seq_len(n_warmup)
sampling_rows <- n_warmup + seq_len(n_iter)
# The sums run over all chains; the means divide by the per-chain iteration
# count, so they are totals over chains per iteration.
(sum_warmup_leapfrogs <- sum(leapfrogs[warmup_rows, ]))
(sum_leapfrogs <- sum(leapfrogs[sampling_rows, ]))
(mean_warmup_leapfrogs <- sum_warmup_leapfrogs / n_warmup)
(mean_leapfrogs <- sum_leapfrogs / n_iter)
# as.array(stanfit) is passed with warmup = 0 so that monitor() does not
# discard any of the supplied draws as warmup.
mon <- rstan::monitor(as.array(stanfit), warmup = 0, print = FALSE)
(maxrhat <- max(mon[, 'Rhat']))
bulk_ess_per_iter <- mon[, 'Bulk_ESS'] / n_iter
tail_ess_per_iter <- mon[, 'Tail_ESS'] / n_iter
bulk_ess_per_leapfrog <- mon[, 'Bulk_ESS'] / sum_leapfrogs
tail_ess_per_leapfrog <- mon[, 'Tail_ESS'] / sum_leapfrogs
min(bulk_ess_per_iter)
min(tail_ess_per_iter)
min(bulk_ess_per_leapfrog)
min(tail_ess_per_leapfrog)
# sampler_params includes the warmup rows, so the step size is read from the
# last row (n_warmup + n_iter), which is a post-warmup iteration. Row n_iter
# is still a warmup iteration whenever n_warmup >= n_iter.
stepsize_draws <- sapply(sampler_params, function(x) x[, "stepsize__"])
(stepsizes <- stepsize_draws[n_warmup + n_iter, ])

# classic warmup
# Compiles the model against the regular CmdStan install, samples through
# cmdstanr with warmup draws saved, reads the CSV files back with rstan and
# prints the same diagnostics as for the adaptive warmup run.
cmdstanr::set_cmdstan_path("~/.cmdstanr/cmdstan")
model <- cmdstanr::cmdstan_model(paste0(modelname, ".stan"), quiet = FALSE)
# save_warmup is documented as a logical flag, so pass TRUE rather than 1.
fit <- model$sample(data = data, save_warmup = TRUE)
stanfit <- rstan::read_stan_csv(fit$output_files())
(n_warmup <- stanfit@sim$warmup)
n_iter <- stanfit@sim$iter - n_warmup
# Sampler diagnostics with the warmup rows included, one list element per chain.
sampler_params <- rstan:::get_sampler_params(stanfit, inc_warmup = TRUE)
leapfrogs <- sapply(sampler_params, function(x) x[, "n_leapfrog__"])
# seq_len() keeps the index empty (instead of c(1, 0)) if a count is zero.
warmup_rows <- seq_len(n_warmup)
sampling_rows <- n_warmup + seq_len(n_iter)
# The sums run over all chains; the means divide by the per-chain iteration
# count, so they are totals over chains per iteration.
(sum_warmup_leapfrogs <- sum(leapfrogs[warmup_rows, ]))
(sum_leapfrogs <- sum(leapfrogs[sampling_rows, ]))
(mean_warmup_leapfrogs <- sum_warmup_leapfrogs / n_warmup)
(mean_leapfrogs <- sum_leapfrogs / n_iter)
# as.array(stanfit) is passed with warmup = 0 so that monitor() does not
# discard any of the supplied draws as warmup.
mon <- rstan::monitor(as.array(stanfit), warmup = 0, print = FALSE)
(maxrhat <- max(mon[, 'Rhat']))
bulk_ess_per_iter <- mon[, 'Bulk_ESS'] / n_iter
tail_ess_per_iter <- mon[, 'Tail_ESS'] / n_iter
bulk_ess_per_leapfrog <- mon[, 'Bulk_ESS'] / sum_leapfrogs
tail_ess_per_leapfrog <- mon[, 'Tail_ESS'] / sum_leapfrogs
min(bulk_ess_per_iter)
min(tail_ess_per_iter)
min(bulk_ess_per_leapfrog)
min(tail_ess_per_leapfrog)
# sampler_params includes the warmup rows, so the step size is read from the
# last row (n_warmup + n_iter), which is a post-warmup iteration. Row n_iter
# is still a warmup iteration whenever n_warmup >= n_iter.
stepsize_draws <- sapply(sampler_params, function(x) x[, "stepsize__"])
(stepsizes <- stepsize_draws[n_warmup + n_iter, ])

EDIT: process_data fix. EDIT2: monitor fix. EDIT3: added *_ess_per_leapfrog and printing. EDIT4: added stepsizes. EDIT5: added maxrhat. EDIT6: fixed monitor to show correct info.

2 Likes