Pass incomplete estimates as initial values for the next round of iterations

I am sending my Stan scripts to run on a supercomputer. However, I don’t know how to estimate how much time the supercomputer will require for the computation (which is one of the parameters you need to pass to the supercomputer), so I am often timed out. I would therefore like to modify my code so that it saves the results of the estimation every, say, 1000 iterations and then passes these still-incomplete estimates as initial values for the next iterations. That way, if I am timed out, I will at least have access to the calculations done up to that point and will be able to restart the estimation from where it stopped.

However, I cannot find a way to pass the results of the previous computations as initial values for further computations. Simply passing the fit object to the init argument of the Stan functions does not work, nor does passing the summary of the fit object. How can I ‘reshape’ the fit object so that its estimates can be used as initial values for further computations?

is this about warmup or sampling iterations? this comes up a lot, and there’s probably a lot of good stuff in previous posts about what you can/can’t achieve.

for CmdStanPy, we’re working on operationalizing this, cf. this feature request

for now, this could be scripted as follows:

  1. run the sampler for 1000 warmup iterations and 1 sampling iteration.

  2. extract the following information from the resulting CmdStanMCMC object:

  • the current best estimate of the Stan program parameters. the method stan_variable(var=<param_name>) will return a numpy.ndarray over all draws where each element of the array has the correct Stan variable structure.

  • the stepsize and metric, available as CmdStanMCMC object properties step_size and metric

  • seed, chain_id for the PRNG used by the sampler. the Stan algorithms don’t record the PRNG state, which complicates the question of how best to resume sampling - perhaps using the same seed and chain_id*2 would be adequate

  3. restart the sampler - specifying seed, chain_id, step_size, metric, and initial parameter values (sketched below).

there’s a certain amount of data-munging required to extract/munge/marshall this information, hence the above-mentioned feature request. also problematic: CmdStan’s stansummary function can’t be used here because it only analyzes sampling iterations, not warmup, but CmdStanPy will let you export the sample to ArviZ, and then you could use their diagnostics. all of this needs more investigation.

CmdStanPy provides the function write_stan_json, which will create the JSON input files required by the metric and inits arguments of the CmdStanModel sample method.
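
putting steps 1-3 together, a rough sketch (untested; the model, data file, and parameter names here are placeholders, and the PRNG caveat from step 2 still applies):

import cmdstanpy

model = cmdstanpy.CmdStanModel(stan_file='model.stan')

# step 1: 1000 warmup iterations, 1 sampling iteration
fit = model.sample(data='data.json', chains=4, seed=12345,
                   iter_warmup=1000, iter_sampling=1)

# step 2: per-chain initial values from the single retained draw,
# plus the adapted step size and metric
param_names = ['theta', 'sigma']  # placeholders for your model's parameters
inits = [{p: fit.stan_variable(var=p)[chain] for p in param_names}
         for chain in range(4)]
metric = [{'inv_metric': m} for m in fit.metric]

# step 3: resume with warmup off, reusing the tuning information
fit2 = model.sample(data='data.json', chains=4, seed=12345,
                    chain_ids=[5, 6, 7, 8],   # fresh PRNG streams
                    iter_warmup=0, adapt_engaged=False,
                    step_size=fit.step_size.tolist(), metric=metric,
                    inits=inits)

(write_stan_json will produce the equivalent JSON files if you need them on disk.)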

good luck!

Bumping this. I am in the same boat (with cmdstanr). I’ve written a resumable sampling script, but extracting the best fit from N iterations and reshaping it in the way that init wants is proving challenging. How does one take the best fit and reshape it? I’d love some help here.

Does the example code in the vignette Bayesian workflow book - Birthdays help?
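
The reshaping itself can be sketched with the posterior package (a minimal, untested sketch, assuming a cmdstanr fit object fit; size-1 dimensions need extra care):

library(posterior)

draws <- fit$draws()   # iterations x chains x variables array

# keep the last iteration of each chain and drop the lp__ column
last <- subset_draws(draws, iteration = niterations(draws))
last <- subset_draws(last, variable = "lp__", exclude = TRUE)

# one named list of parameter values per chain, the shape init expects;
# as_draws_rvars() regroups flat columns like "theta[1]", "theta[2]"
# into structured variables
init <- lapply(seq_len(nchains(last)), function(chain) {
  d <- as_draws_rvars(subset_draws(last, chain = chain))
  lapply(as.list(d), function(v) drop(draws_of(v)))
})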

It did!

For the record, I ended up writing a resumable optimizer that works moderately well:


## Reconstruct the MLE from the Stan CSV output into list form
## (uses stringr for the string helpers and magrittr for the pipe)
library(stringr)
library(magrittr)

reconstruct_stan_csv <- function(path) {
  # Read all the lines from the checkpoint file
  lines <- readLines(path)

  # Get all parameter names from the line starting with "lp__"
  # and separate them by commas
  param_name_start = grep("^lp__", lines)
  param_names <- lines[param_name_start] %>% 
    str_extract("lp__.*") %>% 
    str_split(",") %>% 
    unlist() %>% 
    str_trim()

  ## Count up the dots in each parameter name
  dot_count = param_names %>% 
    str_count("\\.") %>% 
    as.numeric()

  # Do the same with the values, which occur on the next line
  param_values <- lines[param_name_start + 1] %>% 
    str_split(",") %>% 
    unlist() %>% 
    str_trim() %>%
    as.numeric()

  # A list of the number of elements in each group
  group_index_count = list()

  # Find the number of elements in each group
  for (p_id in 1:length(param_names)) {
    param_name = param_names[p_id]
    group_name = extract_group_name(param_name)
    index = extract_index(param_name)

    # If the group doesn't exist, create it
    if (!group_name %in% names(group_index_count)) {
      group_index_count[[group_name]] = index
    } else {
      # Set the count to the elementwise max of the current count and the index
      group_index_count[[group_name]] = pmax(group_index_count[[group_name]], index)
    }
  }

  # Find scalar params
  scalar_param_names = param_names[dot_count == 0]
  scalar_param_values = param_values[dot_count == 0]

  # Find vector params
  vector_param_names = param_names[dot_count == 1]
  vector_param_values = param_values[dot_count == 1]

  # Find matrix params
  matrix_param_names = param_names[dot_count == 2]
  matrix_param_values = param_values[dot_count == 2]

  # Find array params
  array_param_names = param_names[dot_count == 3]
  array_param_values = param_values[dot_count == 3]

  # Reconstruct the scalar params
  scalar_params = list()
  for (i in 1:length(scalar_param_names)) {
    scalar_params[[scalar_param_names[i]]] = scalar_param_values[i]
  }

  # Reconstruct the vector params
  vector_params = list()
  
  # Loop through each vector param if there are any
  if (length(vector_param_names) > 0) {
    for (i in 1:length(vector_param_names)) {
      # Get the group name
      group_name = extract_group_name(vector_param_names[i])

      # Get the index
      index = extract_index(vector_param_names[i])

      # Create the group if it doesn't exist
      if (!group_name %in% names(vector_params)) {
        vector_params[[group_name]] = numeric(group_index_count[[group_name]])
      }

      # Add the value to the group
      vector_params[[group_name]][[index]] = vector_param_values[i]
    }
  }

  # Reconstruct the matrix params
  matrix_params = list()

  # Loop through each matrix param if there are any
  if (length(matrix_param_names) > 0) {
    for (i in 1:length(matrix_param_names)) {
      # Get the group name
      group_name = extract_group_name(matrix_param_names[i])

      # Get the index
      index = extract_index(matrix_param_names[i])

      # Create the group if it doesn't exist
      if (!group_name %in% names(matrix_params)) {
        matrix_params[[group_name]] = matrix(
          nrow = group_index_count[[group_name]][1],
          ncol = group_index_count[[group_name]][2])
      }

      # Add the value to the group
      matrix_params[[group_name]][[index[1], index[2]]] = matrix_param_values[i]
    }
  }

  # Reconstruct the array params
  array_params = list()

  # Loop through each array param if there are any
  if(length(array_param_names) > 0) {
    for (i in 1:length(array_param_names)) {
      # Get the group name
      group_name = extract_group_name(array_param_names[i])

      # Get the index
      index = extract_index(array_param_names[i])

      # Create the group if it doesn't exist
      if (!group_name %in% names(array_params)) {
        array_params[[group_name]] = array(
          dim = group_index_count[[group_name]])
      }

      # Add the value to the group
      array_params[[group_name]][[index[1], index[2], index[3]]] = array_param_values[i]
    }
  }

  return(c(
    scalar_params,
    vector_params,
    matrix_params,
    array_params
  ))
}
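
## NOTE: the helpers extract_group_name() and extract_index() were not
## shown in the original post; the definitions below are a plausible
## reconstruction, assuming CmdStan's dotted CSV names ("theta.2.3" ->
## group "theta", index c(2, 3))
extract_group_name <- function(param_name) {
  # everything before the first dot is the variable name
  str_split(param_name, "\\.")[[1]][1]
}

extract_index <- function(param_name) {
  # the dotted suffix, if any, is the (possibly multi-dimensional) index
  parts <- str_split(param_name, "\\.")[[1]]
  as.numeric(parts[-1])
}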

## This function handles resumable optimization for stan models.
resume_optimize <- function(
  opt,
  model, 
  data,
  init,
  checkpoint_directory,
  checkpoint_every = 100,
  threads = 1,
  tol_obj = 1e-5,
  algorithm = "lbfgs",
  init_alpha = 0.0001
) {
  # Check if checkpoint directory exists
  if (!dir.exists(checkpoint_directory)) {
    dir.create(checkpoint_directory, recursive = TRUE, showWarnings = FALSE)
  }

  # Make the directory checkpoint_directory/diagnostics
  if (!dir.exists(file.path(checkpoint_directory, "diagnostics"))) {
    dir.create(file.path(checkpoint_directory, "diagnostics"))
  }

  # Find the most recent checkpoint file
  checkpoint_files <- list.files(checkpoint_directory, pattern = "checkpoint.*")

  # Extract the checkpoint number from the file name,
  # checkpoint.1-23023.csv -> 1
  # checkpoint.2-2302512.csv -> 2
  checkpoint_numbers <- as.numeric(str_extract(checkpoint_files, "[0-9]+"))
  
  # If there are no checkpoint files, start fresh; otherwise resume from the last one
  if (length(checkpoint_files) == 0) {
    seed = 0
  } else {
    # Load the most recent checkpoint
    prev_checkpoint_number = max(checkpoint_numbers)
    checkpoint_file <- checkpoint_files[which.max(checkpoint_numbers)]

    # Read all the lines from the checkpoint file
    lines <- readLines(file.path(checkpoint_directory, checkpoint_file))

    # Use regex to extract the seed from a line that looks like
    #   seed = 0
    seed = lines[grepl("seed = [0-9]+", lines)] %>% 
      str_extract("[0-9]+") %>% 
      as.numeric()

    ## Reconstruct the parameters from the checkpoint file
    init = list(reconstruct_stan_csv(file.path(checkpoint_directory, checkpoint_file)))
  }

  # Number for the next checkpoint
  checkpoint_number <- ifelse(
    length(checkpoint_numbers) > 0,
    max(checkpoint_numbers, na.rm = TRUE) + 1,
    1
  )

  cat("Running optimization for checkpoint ", checkpoint_number, "\n")

  # Redirect console output for this run to a per-checkpoint diagnostics file
  sink(file = file.path(checkpoint_directory, "diagnostics",
                        paste0(checkpoint_number, ".txt")),
       type = "output", split = FALSE)
  opt = model$optimize(
    data = data, 
    tol_obj = tol_obj, 
    algorithm = algorithm, 
    init_alpha = init_alpha, 
    threads = threads,
    init = init,
    iter = checkpoint_every,
    seed = seed
  )
  sink()

  opt$save_output_files(
    dir = checkpoint_directory,
    basename = paste0("checkpoint-", checkpoint_number)
  )

  # Read the diagnostic file, check to see if the text file contains the text "Convergence detected"
  # Return the boolean value for whether this optimization converged
  opt_converged <- readLines(file.path(checkpoint_directory, "diagnostics", paste0(checkpoint_number, ".txt"))) %>% 
    str_detect("Convergence detected") %>% 
    any()

  return(list(opt=opt, checkpoint_number=checkpoint_number, converged = opt_converged))
}

resumable_optimize <- function(
  model, 
  data,
  init,
  checkpoint_directory,
  checkpoint_every = 100,
  threads = 1,
  tol_obj = 1e-5,
  algorithm = "lbfgs",
  init_alpha = 0.0001
) {
  # Initialize opt
  opt = list(opt=NA, checkpoint_number = NA)

  # Loop until convergence
  while (TRUE) {
    # Run optimization
    opt <- resume_optimize(
      opt = opt,
      model = model,
      data = data,
      init = init,
      checkpoint_directory = checkpoint_directory,
      checkpoint_every = checkpoint_every,
      threads = threads,
      tol_obj = tol_obj,
      algorithm = algorithm,
      init_alpha = init_alpha
    )

    # Check if optimization has converged
    if (opt$converged) {
      cat("\nOptimization has converged\n")
      break
    }
  }

  return(opt$opt)
}
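
A typical call looks like this (the model file and data below are placeholders); if the job is killed, rerunning the same call resumes from the newest checkpoint in checkpoint_directory:

library(cmdstanr)

model <- cmdstan_model("model.stan")    # placeholder model

fit <- resumable_optimize(
  model = model,
  data = list(N = 10, y = rnorm(10)),   # placeholder data
  init = 0.5,
  checkpoint_directory = "checkpoints",
  checkpoint_every = 100
)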

Are there benefits to this compared to checkpointing with chkptstanr?

Wow! I didn’t know chkptstanr existed! This is great, thank you @avehtari.
