I have read previous discussions and posts about resuming sampling after an interruption when using CmdStan. With the help of some folks at my institution, we managed to do this using a tool called DMTCP (Distributed MultiThreaded CheckPointing). It appears to work technically, and I am now testing whether the resulting inference is reliable using simulated data and a complex model.
I wonder whether anybody has experience with this approach, has tested it, or has feedback and insights on the implementation. I also want to describe my workflow and how it runs on our HPC cluster, in case it is helpful to someone else, and I am especially interested in hearing what more experienced folks think about how reliable the final inference is under this scheme.
The models I am running take weeks, and there have been many incidents where I lost jobs that had been running for weeks and had to start from scratch. When the folks who helped me described this approach, it sounded user-friendly and straightforward, even for someone like me who doesn't know much about checkpointing; almost too good to be true. Before committing to changing my workflow, I wanted to hear whether others have used this before and whether it sounds legitimate and reliable. Can I still trust the inference at the very end of this process?
STEP 1
I use an R script to dump the data and starting values to JSON files and to compile the model from the Stan file.
library(cmdstanr)
library(jsonlite)

setwd('/gpfs/projects/edquant/cengiz/stan_checkpointing')
################################################################
# Dump the response data to JSON
Y <- read.csv('Y.csv')

data_resp <- list(
  J     = length(unique(Y$id)),
  I     = length(unique(Y$item)),
  n_obs = nrow(Y),
  p_loc = Y$id,
  i_loc = Y$item,
  RT    = log(Y$RT),
  Y     = Y$R
)

data_json <- toJSON(data_resp, pretty = TRUE, auto_unbox = TRUE)
write(data_json, "dghirt_data.json")
################################################################
# Just to test how this works with starting parameters:
# made-up starting values for the parameters in the model
theta <- matrix(0, nrow = length(unique(Y$id)), ncol = 2)
delta <- matrix(c(0.5, 2), nrow = length(unique(Y$id)), ncol = 2, byrow = TRUE)
item  <- matrix(c(0.5, 4, 0), nrow = length(unique(Y$item)), ncol = 3, byrow = TRUE)
H     <- rep(0.5, length(unique(Y$id)))
C     <- rep(0.5, length(unique(Y$item)))

start <- list(item   = item,
              person = theta,
              delta  = delta,
              pH     = H,
              pC     = C)

start_json <- toJSON(start, pretty = TRUE, auto_unbox = TRUE)
write(start_json, "start.json")
################################################################
# Compile the model syntax
mod <- cmdstan_model("dghirt_cpu.stan",
                     cpp_options = list(stan_threads = TRUE))
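Before moving to STEP 2, a quick sanity check that the JSON files actually parse can save a wasted submission. This is a minimal sketch, assuming python3 is available on the cluster; check_json is my own helper, not part of the workflow above.

```shell
# Sketch: verify the JSON files written by the R script parse before
# submitting the job; assumes python3 is available on the login node.
check_json() {
  # $1: path to a JSON file; succeeds only if the file exists and parses
  python3 -m json.tool "$1" > /dev/null 2>&1
}

for f in dghirt_data.json start.json; do
  check_json "$f" && echo "$f: OK" || echo "$f: missing or invalid"
done
```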
STEP 2
We then submit the following SLURM script:
#!/bin/bash
#SBATCH --account=edquant
#SBATCH --partition=cengiz
#SBATCH --job-name=dghirt_cpu
#SBATCH --output=dghirt_cpu.out
#SBATCH --error=dghirt_cpu.err
#SBATCH --time=0-00:10:00
#SBATCH --mem=16gb
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=4
#SBATCH --chdir=/gpfs/projects/edquant/cengiz/stan_checkpointing
#SBATCH --nodelist=n0303
module load dmtcp
dmtcp_launch --interval 540 ./dghirt_cpu sample \
  num_chains=4 num_samples=500 num_warmup=500 save_warmup=true \
  adapt delta=0.99 \
  data file=dghirt_data.json \
  init=start.json \
  random seed=1234 \
  output file=sample.csv diagnostic_file=diagnostic.csv refresh=10 \
  num_threads=4
This runs four parallel chains (500 warm-up and 500 sampling iterations each). I deliberately set the time limit to 10 minutes, so the job is terminated before sampling finishes. The --interval 540 option tells DMTCP to take a snapshot every 540 seconds, i.e., at the 9th minute, just before the job is killed without completing.
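As a sketch of how the interval can be chosen (my own rule of thumb, not a DMTCP requirement): subtract a safety margin from the SLURM walltime so the snapshot completes before the scheduler kills the job.

```shell
# Sketch: derive the DMTCP checkpoint interval from the SLURM time limit.
# The 60-second margin is an assumption; widen it if snapshots are slow to write.
limit_seconds=600                            # --time=0-00:10:00
margin_seconds=60
interval=$((limit_seconds - margin_seconds))
echo "checkpoint interval: ${interval} seconds"   # prints: checkpoint interval: 540 seconds
```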
After submitting this batch file, sampling starts and the CSV files are created in the working directory (sample_1.csv, sample_2.csv, sample_3.csv, and sample_4.csv). It also writes the following output file:
method = sample (Default)
  sample
    num_samples = 500
    num_warmup = 500
    save_warmup = true
    thin = 1 (Default)
    adapt
      engaged = true (Default)
      gamma = 0.05 (Default)
      delta = 0.99
      kappa = 0.75 (Default)
      t0 = 10 (Default)
      init_buffer = 75 (Default)
      term_buffer = 50 (Default)
      window = 25 (Default)
      save_metric = false (Default)
    algorithm = hmc (Default)
      hmc
        engine = nuts (Default)
          nuts
            max_depth = 10 (Default)
        metric = diag_e (Default)
        metric_file =  (Default)
        stepsize = 1 (Default)
        stepsize_jitter = 0 (Default)
    num_chains = 4
id = 1 (Default)
data
  file = dghirt_data.json
init = start.json
random
  seed = 1234
output
  file = sample.csv
  diagnostic_file = diagnostic.csv
  refresh = 10
  sig_figs = -1 (Default)
  profile_file = profile.csv (Default)
  save_cmdstan_config = false (Default)
num_threads = 4
Gradient evaluation took 0.014865 seconds
1000 transitions using 10 leapfrog steps per transition would take 148.65 seconds.
Adjust your expectations accordingly!
Rejecting initial value:
Error evaluating the log probability at the initial value.
Exception: lognormal_lpdf: Random variable is -0.094242, but must be nonnegative! (in '/tmp/RtmpDCWw91/model-2af6d658911293.stan', line 89, column 2 to column 31)
Rejecting initial value:
Error evaluating the log probability at the initial value.
Exception: lognormal_lpdf: Random variable is -0.450966, but must be nonnegative! (in '/tmp/RtmpDCWw91/model-2af6d658911293.stan', line 89, column 2 to column 31)
Rejecting initial value:
Error evaluating the log probability at the initial value.
Exception: lognormal_lpdf: Random variable is -0.326859, but must be nonnegative! (in '/tmp/RtmpDCWw91/model-2af6d658911293.stan', line 89, column 2 to column 31)
Rejecting initial value:
Error evaluating the log probability at the initial value.
Exception: lognormal_lpdf: Random variable is -0.116623, but must be nonnegative! (in '/tmp/RtmpDCWw91/model-2af6d658911293.stan', line 89, column 2 to column 31)
Gradient evaluation took 0.00844 seconds
1000 transitions using 10 leapfrog steps per transition would take 84.4 seconds.
Adjust your expectations accordingly!
Gradient evaluation took 0.008445 seconds
1000 transitions using 10 leapfrog steps per transition would take 84.45 seconds.
Adjust your expectations accordingly!
Rejecting initial value:
Error evaluating the log probability at the initial value.
Exception: lognormal_lpdf: Random variable is -0.793705, but must be nonnegative! (in '/tmp/RtmpDCWw91/model-2af6d658911293.stan', line 89, column 2 to column 31)
Gradient evaluation took 0.008456 seconds
1000 transitions using 10 leapfrog steps per transition would take 84.56 seconds.
Adjust your expectations accordingly!
Chain [3] Iteration: 1 / 1000 [ 0%] (Warmup)
Chain [4] Iteration: 1 / 1000 [ 0%] (Warmup)
Chain [2] Iteration: 1 / 1000 [ 0%] (Warmup)
Chain [1] Iteration: 1 / 1000 [ 0%] (Warmup)
Chain [1] Iteration: 10 / 1000 [ 1%] (Warmup)
Chain [2] Iteration: 10 / 1000 [ 1%] (Warmup)
Chain [4] Iteration: 10 / 1000 [ 1%] (Warmup)
Chain [3] Iteration: 10 / 1000 [ 1%] (Warmup)
Chain [2] Iteration: 20 / 1000 [ 2%] (Warmup)
Chain [4] Iteration: 20 / 1000 [ 2%] (Warmup)
Chain [3] Iteration: 20 / 1000 [ 2%] (Warmup)
Chain [1] Iteration: 20 / 1000 [ 2%] (Warmup)
The job was terminated after 10 minutes, when only 20 iterations had been completed. At the 9th minute, DMTCP saved a restart shell script (dmtcp_restart_script.sh) and a checkpoint image (ckpt_dghirt_cpu.dmtcp) in the working directory.
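Before resubmitting, it seems prudent to check that DMTCP actually wrote both files. A minimal sketch, using the file names produced by the run above (have_checkpoint is my own helper):

```shell
# Sketch: confirm DMTCP wrote both the restart script and a checkpoint
# image before submitting the restart job.
have_checkpoint() {
  # $1: directory to inspect; succeeds only if the restart script is
  # executable and at least one checkpoint image exists
  [ -x "$1/dmtcp_restart_script.sh" ] && ls "$1"/ckpt_*.dmtcp > /dev/null 2>&1
}

if have_checkpoint .; then
  echo "checkpoint found; safe to resubmit the restart job"
else
  echo "no checkpoint yet; submit the original dmtcp_launch job instead"
fi
```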
STEP 3
We submit a slightly updated SLURM script, shown below. It restarts the process from the checkpoint via the .sh file saved before the previous job was terminated. The restarted process again writes a new restart script every 9 minutes until the job is terminated once more.
#!/bin/bash
#SBATCH --account=edquant
#SBATCH --partition=cengiz
#SBATCH --job-name=dghirt_resubmit
#SBATCH --output=/gpfs/projects/edquant/cengiz/stan_checkpointing/dghirt_resubmit2.out
#SBATCH --error=/gpfs/projects/edquant/cengiz/stan_checkpointing/dghirt_resubmit2.err
#SBATCH --time=0-00:10:00
#SBATCH --mem=16gb
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=4
#SBATCH --chdir=/gpfs/projects/edquant/cengiz/stan_checkpointing
#SBATCH --nodelist=n0303
module load dmtcp
./dmtcp_restart_script.sh --interval 540
As this new job runs, I can see the CSV files from the previous job being updated with new iterations as if nothing happened. The output file also shows sampling resuming from where it left off, until the job is again terminated after 10 minutes.
[2025-02-11T15:15:32.924, 40000, 40002, Note] at processinfo.cpp:458 in restoreHeap; REASON='Area between saved_break and curr_break not mapped, mapping it now
_savedBrk = 8519680
curBrk = 304754688
Chain [2] Iteration: 20 / 1000 [ 2%] (Warmup)
Chain [4] Iteration: 20 / 1000 [ 2%] (Warmup)
Chain [3] Iteration: 20 / 1000 [ 2%] (Warmup)
Chain [1] Iteration: 20 / 1000 [ 2%] (Warmup)
Chain [2] Iteration: 30 / 1000 [ 3%] (Warmup)
Chain [4] Iteration: 30 / 1000 [ 3%] (Warmup)
Chain [3] Iteration: 30 / 1000 [ 3%] (Warmup)
Chain [1] Iteration: 30 / 1000 [ 3%] (Warmup)
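One quick way to verify that a restart appended to the existing files rather than starting over (a sketch of my own, not part of DMTCP) is to count the saved iterations in a chain's CSV before and after resubmitting; the count should keep growing. CmdStan writes its configuration as '#' comment lines, followed by one header row and then one row per saved iteration.

```shell
# Sketch: count saved iterations in a CmdStan output CSV.
# Comment lines start with '#'; the first remaining line is the column
# header, so the number of draws is the non-comment line count minus one.
count_draws() {
  # $1: path to a CmdStan CSV, e.g. sample_1.csv
  echo $(( $(grep -vc '^#' "$1") - 1 ))
}
```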
So, for long-running jobs, one can take a snapshot once a day. If something happens on the HPC cluster and the job is lost, or if the job hits its time limit, you don't lose everything: you can resume from the last checkpoint, which is at most a day old.
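The daily-snapshot idea could be folded into a single resubmittable SLURM script that either launches fresh or restarts from the last checkpoint. This is a sketch of the branch logic only; next_command is a hypothetical helper, and the dmtcp_launch line would need the full CmdStan arguments from STEP 2.

```shell
# Sketch: restart from the last DMTCP checkpoint if one exists, otherwise
# launch fresh. An interval of 86400 seconds checkpoints once a day.
next_command() {
  # $1: working directory of the run
  if [ -x "$1/dmtcp_restart_script.sh" ]; then
    echo "$1/dmtcp_restart_script.sh --interval 86400"
  else
    # append the full 'sample ...' arguments from STEP 2 here
    echo "dmtcp_launch --interval 86400 ./dghirt_cpu sample"
  fi
}
# Inside the batch script, eval "$(next_command .)" would run the right command.
```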