Benchmarking and resuming sampling via DMTCP after interruption

I have read previous discussions and posts about resuming sampling after an interruption with CmdStan. With the help of some folks at my institution, we managed to do this using a tool called DMTCP. Technically it looks like it is working, but I am still testing whether the inference is reliable, using simulated data and a complex model.

I wonder whether anybody has experience doing this, has tested it, or has feedback and insights about the implementation.

I also want to describe my workflow and how it works on our HPC cluster, in case it is helpful to somebody else. I am also interested in hearing what more experienced folks think about how reliable the final inference is with this implementation.

The models I am running take weeks, and there have been many incidents where I lost jobs that had been running for weeks and had to start from scratch. When the folks at my institution described this approach, it sounded user-friendly and straightforward, even for someone like me who doesn't know much about it; it almost sounded too good to be true. Before committing to changing my workflow, I wanted to hear others' opinions on whether they have used this before and whether it seems legitimate and reliable. Can I still trust the inference at the very end of this process?

STEP 1

I use an R script to dump the data and starting parameter values into JSON files and to compile the model from the Stan file.

require(cmdstanr)
require(jsonlite)

setwd('/gpfs/projects/edquant/cengiz/stan_checkpointing')
################################################################
Y <- read.csv('Y.csv')

data_resp <- list(
  J     = length(unique(Y$id)),    # number of persons
  I     = length(unique(Y$item)),  # number of items
  n_obs = nrow(Y),                 # number of observations
  p_loc = Y$id,                    # person index for each observation
  i_loc = Y$item,                  # item index for each observation
  RT    = log(Y$RT),               # log response times
  Y     = Y$R                      # item responses
)

data_json = toJSON(data_resp, pretty=TRUE, auto_unbox=TRUE)
write(data_json, "dghirt_data.json")
################################################################
# Just to test how this works with start parameters
# Made up some starting values for parameters in the model

theta <- matrix(0,nrow=length(unique(Y$id)),ncol=2)
delta <- matrix(c(0.5,2),nrow=length(unique(Y$id)),ncol=2,byrow=T)
item <- matrix(c(0.5,4,0),nrow=length(unique(Y$item)),ncol=3,byrow=T)
H <- rep(0.5,length(unique(Y$id)))
C <- rep(0.5,length(unique(Y$item)))

start <- list(item   = item,
              person = theta,
              delta  = delta,
              pH     = H,
              pC     = C)

start_json = toJSON(start, pretty=TRUE, auto_unbox=TRUE)
write(start_json, "start.json")
################################################################
# Compile the model syntax

mod <- cmdstan_model("dghirt_cpu.stan",
                     cpp_options = list(stan_threads = TRUE))
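
Step 2 calls the compiled binary directly as ./dghirt_cpu, so it needs to end up in this working directory. With the defaults above, cmdstanr should compile it next to the .stan file; the path can be double-checked with a small sketch like this:

# Path of the compiled binary that dmtcp_launch will call in Step 2;
# with the defaults above this should be ./dghirt_cpu in the working directory.
mod$exe_file()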

STEP 2

We use the following SLURM script

#!/bin/bash
#SBATCH --account=edquant       
#SBATCH --partition=cengiz
#SBATCH --job-name=dghirt_cpu    
#SBATCH --output=dghirt_cpu.out  
#SBATCH --error=dghirt_cpu.err   
#SBATCH --time=0-00:10:00              
#SBATCH --mem=16gb              
#SBATCH --nodes=1               
#SBATCH --ntasks-per-node=1     
#SBATCH --cpus-per-task=4       
#SBATCH --chdir=/gpfs/projects/edquant/cengiz/stan_checkpointing
#SBATCH --nodelist=n0303

module load dmtcp

dmtcp_launch --interval 540 ./dghirt_cpu sample num_chains=4 num_samples=500 num_warmup=500 save_warmup=true adapt delta=0.99 data file=dghirt_data.json init=start.json random seed=1234 output file=sample.csv diagnostic_file=diagnostic.csv refresh=10 num_threads=4

This runs four parallel chains (500 warm-up and 500 sampling iterations each). I set the time limit to 10 minutes, so the job will be terminated before sampling finishes. --interval 540 tells DMTCP to write a checkpoint every 540 seconds, so a snapshot is taken at the 9-minute mark, just before the job is killed at the 10-minute limit.

After this batch file is submitted, the job starts sampling and creates one CSV file per chain in the working directory (sample_1.csv, sample_2.csv, sample_3.csv, and sample_4.csv). It also writes the following to the output file.

method = sample (Default)
  sample
    num_samples = 500
    num_warmup = 500
    save_warmup = true
    thin = 1 (Default)
    adapt
      engaged = true (Default)
      gamma = 0.05 (Default)
      delta = 0.99
      kappa = 0.75 (Default)
      t0 = 10 (Default)
      init_buffer = 75 (Default)
      term_buffer = 50 (Default)
      window = 25 (Default)
      save_metric = false (Default)
    algorithm = hmc (Default)
      hmc
        engine = nuts (Default)
          nuts
            max_depth = 10 (Default)
        metric = diag_e (Default)
        metric_file =  (Default)
        stepsize = 1 (Default)
        stepsize_jitter = 0 (Default)
    num_chains = 4
id = 1 (Default)
data
  file = dghirt_data.json
init = start.json
random
  seed = 1234
output
  file = sample.csv
  diagnostic_file = diagnostic.csv
  refresh = 10
  sig_figs = -1 (Default)
  profile_file = profile.csv (Default)
  save_cmdstan_config = false (Default)
num_threads = 4


Gradient evaluation took 0.014865 seconds
1000 transitions using 10 leapfrog steps per transition would take 148.65 seconds.
Adjust your expectations accordingly!


Rejecting initial value:
  Error evaluating the log probability at the initial value.
Exception: lognormal_lpdf: Random variable is -0.094242, but must be nonnegative! (in '/tmp/RtmpDCWw91/model-2af6d658911293.stan', line 89, column 2 to column 31)
Rejecting initial value:
  Error evaluating the log probability at the initial value.
Exception: lognormal_lpdf: Random variable is -0.450966, but must be nonnegative! (in '/tmp/RtmpDCWw91/model-2af6d658911293.stan', line 89, column 2 to column 31)
Rejecting initial value:
  Error evaluating the log probability at the initial value.
Exception: lognormal_lpdf: Random variable is -0.326859, but must be nonnegative! (in '/tmp/RtmpDCWw91/model-2af6d658911293.stan', line 89, column 2 to column 31)
Rejecting initial value:
  Error evaluating the log probability at the initial value.
Exception: lognormal_lpdf: Random variable is -0.116623, but must be nonnegative! (in '/tmp/RtmpDCWw91/model-2af6d658911293.stan', line 89, column 2 to column 31)

Gradient evaluation took 0.00844 seconds
1000 transitions using 10 leapfrog steps per transition would take 84.4 seconds.
Adjust your expectations accordingly!



Gradient evaluation took 0.008445 seconds
1000 transitions using 10 leapfrog steps per transition would take 84.45 seconds.
Adjust your expectations accordingly!


Rejecting initial value:
  Error evaluating the log probability at the initial value.
Exception: lognormal_lpdf: Random variable is -0.793705, but must be nonnegative! (in '/tmp/RtmpDCWw91/model-2af6d658911293.stan', line 89, column 2 to column 31)

Gradient evaluation took 0.008456 seconds
1000 transitions using 10 leapfrog steps per transition would take 84.56 seconds.
Adjust your expectations accordingly!


Chain [3] Iteration:   1 / 1000 [  0%]  (Warmup)
Chain [4] Iteration:   1 / 1000 [  0%]  (Warmup)
Chain [2] Iteration:   1 / 1000 [  0%]  (Warmup)
Chain [1] Iteration:   1 / 1000 [  0%]  (Warmup)
Chain [1] Iteration:  10 / 1000 [  1%]  (Warmup)
Chain [2] Iteration:  10 / 1000 [  1%]  (Warmup)
Chain [4] Iteration:  10 / 1000 [  1%]  (Warmup)
Chain [3] Iteration:  10 / 1000 [  1%]  (Warmup)
Chain [2] Iteration:  20 / 1000 [  2%]  (Warmup)
Chain [4] Iteration:  20 / 1000 [  2%]  (Warmup)
Chain [3] Iteration:  20 / 1000 [  2%]  (Warmup)
Chain [1] Iteration:  20 / 1000 [  2%]  (Warmup)

The job was terminated after 10 minutes, when each chain had completed only 20 iterations.

At the 9th minute, DMTCP saves a restart shell script (dmtcp_restart_script.sh) and a checkpoint file (ckpt_dghirt_cpu.dmtcp) in the same working directory.

STEP 3

We then submit a slightly modified SLURM script, shown below. It restarts the process from the restart script that was saved before the previous job was terminated. The restarted job keeps writing a new checkpoint (and restart script) every 9 minutes until it is terminated again.

#!/bin/bash
#SBATCH --account=edquant       
#SBATCH --partition=cengiz
#SBATCH --job-name=dghirt_resubmit    
#SBATCH --output=/gpfs/projects/edquant/cengiz/stan_checkpointing/dghirt_resubmit2.out  
#SBATCH --error=/gpfs/projects/edquant/cengiz/stan_checkpointing/dghirt_resubmit2.err   
#SBATCH --time=0-00:10:00              
#SBATCH --mem=16gb              
#SBATCH --nodes=1               
#SBATCH --ntasks-per-node=1     
#SBATCH --cpus-per-task=4       
#SBATCH --chdir=/gpfs/projects/edquant/cengiz/stan_checkpointing
#SBATCH --nodelist=n0303

module load dmtcp

./dmtcp_restart_script.sh --interval 540

As this new job runs, I can see the CSV files from the previous job get updated with new iterations as if nothing had happened. The output file likewise resumes from where the sampler left off, until the job is terminated again after 10 minutes.

[2025-02-11T15:15:32.924, 40000, 40002, Note] at processinfo.cpp:458 in restoreHeap; REASON='Area between saved_break and curr_break not mapped, mapping it now
     _savedBrk = 8519680
     curBrk = 304754688

Chain [2] Iteration:  20 / 1000 [  2%]  (Warmup)
Chain [4] Iteration:  20 / 1000 [  2%]  (Warmup)
Chain [3] Iteration:  20 / 1000 [  2%]  (Warmup)
Chain [1] Iteration:  20 / 1000 [  2%]  (Warmup)
Chain [2] Iteration:  30 / 1000 [  3%]  (Warmup)
Chain [4] Iteration:  30 / 1000 [  3%]  (Warmup)
Chain [3] Iteration:  30 / 1000 [  3%]  (Warmup)
Chain [1] Iteration:  30 / 1000 [  3%]  (Warmup)

So, for long-running jobs, one could take a snapshot once a day. If something happens on the HPC cluster, the job disappears, or it hits its time limit, you don't lose everything: you can restart from the last checkpoint and lose at most a day of work.
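
One way to avoid juggling two separate scripts for long jobs would be a single resubmittable script along these lines (an untested sketch; the one-day time limit and the daily 86400-second checkpoint interval are just illustrative):

#!/bin/bash
#SBATCH --account=edquant
#SBATCH --partition=cengiz
#SBATCH --job-name=dghirt_cpu
#SBATCH --time=1-00:00:00
#SBATCH --mem=16gb
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=4
#SBATCH --chdir=/gpfs/projects/edquant/cengiz/stan_checkpointing
#SBATCH --nodelist=n0303

module load dmtcp

# If a restart script from an earlier checkpoint exists, resume from it;
# otherwise start sampling from scratch. Either way, checkpoint once a day.
if [ -f ./dmtcp_restart_script.sh ]; then
  ./dmtcp_restart_script.sh --interval 86400
else
  dmtcp_launch --interval 86400 ./dghirt_cpu sample num_chains=4 num_samples=500 num_warmup=500 save_warmup=true adapt delta=0.99 data file=dghirt_data.json init=start.json random seed=1234 output file=sample.csv diagnostic_file=diagnostic.csv refresh=10 num_threads=4
fi

The same file can then be resubmitted after every interruption, and it will either launch fresh or pick up from the latest checkpoint.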

While testing this process, I interrupted the sampling several times (every couple of hours) and copied the current state of each chain's CSV file at each interruption.

The files are in this folder.

stan_checkpoints

For instance, sample_1_2.csv, sample_1_3.csv, and sample_1_4.csv are the state of samples at three different interruptions for Chain 1.

The files are identical, and every time I restart the process from the previous checkpoint, it just picks up from where it was.
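
For anyone who wants to check this on their own runs, here is roughly how the overlapping rows of two snapshots can be compared in R (a sketch that assumes the usual CmdStan CSV layout, where comment lines start with #):

# Read two snapshots of chain 1 taken at different interruptions,
# skipping CmdStan's "#" comment lines.
a <- read.csv("sample_1_2.csv", comment.char = "#")
b <- read.csv("sample_1_3.csv", comment.char = "#")

# The later snapshot should simply extend the earlier one:
# the rows they have in common should be identical.
n <- min(nrow(a), nrow(b))
all.equal(a[seq_len(n), ], b[seq_len(n), ])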

Can someone please look at these files and confirm that interrupting and restarting from the previous checkpoint is no different from running the whole thing without interruption, and that these samples are safe to use for inference? I appreciate any opinion from someone who knows how these things work under the hood.

Thank you.

Thanks for sharing, this is very exciting!
I remember last year there was a big update to the chkptstanr package to stop and resume sampling with brms and cmdstanr, but there were some outstanding issues regarding checkpointing during the warmup phase and I’m not sure if those have been addressed.