It did!
For the record, I ended up writing a resumable optimizer that works moderately well:
## Reconstruct MLE from the stan CSV into a list form
##
## Parses a Stan output CSV (such as a checkpoint written by resume_optimize())
## into a flat named list that can be passed back as an `init` entry.
##
## The header row is the first line beginning with "lp__"; values are read
## from the line immediately after it. Each column is classified by how many
## "." separators its name contains:
##   0 dots  -> scalar
##   1 dot   -> numeric vector
##   2 dots  -> matrix
##   3+ dots -> array with one dimension per index
## Group names and indices come from extract_group_name() / extract_index()
## (defined elsewhere); each group's dimensions are the elementwise maximum of
## the indices seen for that group. Cells with no matching column stay NA.
##
## Every column is returned, including "lp__" itself, not only declared
## parameters.
##
## path: path to the CSV file.
## Returns a named list ordered scalars, vectors, matrices, arrays.
## Errors if no "lp__" header followed by a value row is found, or if the
## header and value rows have different numbers of fields.
reconstruct_stan_csv <- function(path) {
  lines <- readLines(path)

  header_line <- grep("^lp__", lines)[1]
  if (is.na(header_line) || header_line >= length(lines)) {
    stop(
      "No 'lp__' header followed by a row of values found in ", path,
      call. = FALSE
    )
  }

  # Split one comma-separated row into trimmed fields.
  split_row <- function(row) {
    row %>%
      str_split(",") %>%
      unlist() %>%
      str_trim()
  }

  param_names <- lines[header_line] %>%
    str_extract("lp__.*") %>%
    split_row()
  param_values <- lines[header_line + 1] %>%
    split_row() %>%
    as.numeric()
  # A mismatch would silently pair names with the wrong values.
  if (length(param_names) != length(param_values)) {
    stop(
      "Header has ", length(param_names), " fields but the value row has ",
      length(param_values), " in ", path,
      call. = FALSE
    )
  }

  # Number of indices per name, e.g. "theta" -> 0, "beta.2" -> 1, "L.1.2" -> 2.
  dot_count <- str_count(param_names, "\\.")

  # Size of every group: elementwise max of all indices seen for it.
  group_index_count <- list()
  for (p_id in seq_along(param_names)) {
    group_name <- extract_group_name(param_names[p_id])
    index <- extract_index(param_names[p_id])
    if (!group_name %in% names(group_index_count)) {
      group_index_count[[group_name]] <- index
    } else {
      group_index_count[[group_name]] <-
        pmax(group_index_count[[group_name]], index)
    }
  }

  # Scalars map one-to-one onto list entries.
  is_scalar <- dot_count == 0
  scalar_params <- as.list(
    setNames(param_values[is_scalar], param_names[is_scalar])
  )

  # Vectors stay plain numeric vectors (not 1-d arrays).
  vector_params <- list()
  for (i in which(dot_count == 1)) {
    group_name <- extract_group_name(param_names[i])
    index <- extract_index(param_names[i])
    if (!group_name %in% names(vector_params)) {
      vector_params[[group_name]] <-
        rep(NA_real_, group_index_count[[group_name]])
    }
    vector_params[[group_name]][index] <- param_values[i]
  }

  # Build one dense array per group for the given column positions. A
  # 2-element dim produces an R matrix; longer dims produce arrays of any
  # rank. A 1-row index matrix addresses a single cell regardless of rank.
  assemble_dense <- function(positions) {
    out <- list()
    for (i in positions) {
      group_name <- extract_group_name(param_names[i])
      index <- extract_index(param_names[i])
      if (!group_name %in% names(out)) {
        out[[group_name]] <-
          array(NA_real_, dim = group_index_count[[group_name]])
      }
      out[[group_name]][matrix(index, nrow = 1)] <- param_values[i]
    }
    out
  }

  matrix_params <- assemble_dense(which(dot_count == 2))
  # Previously only exactly-3-index names were kept; higher ranks were dropped.
  array_params <- assemble_dense(which(dot_count >= 3))

  c(
    scalar_params,
    vector_params,
    matrix_params,
    array_params
  )
}
## This function handles resumable optimization for stan models.
##
## Runs one chunk of at most `checkpoint_every` optimizer iterations and saves
## it as the next numbered checkpoint in `checkpoint_directory`.
##
## If checkpoint CSVs ("checkpoint-<N>...csv") already exist, the one with the
## largest N supplies the starting values (via reconstruct_stan_csv()) and the
## seed recorded in its "seed = " line; in that case the `init` argument is
## ignored. Otherwise the run starts from `init` with seed 0.
##
## The printed (standard output) text of model$optimize() is captured in
## "<checkpoint_directory>/diagnostics/<checkpoint_number>.txt" and scanned
## for "Convergence detected".
##
## opt                  Result of a previous call; currently unused, kept for
##                      interface compatibility.
## model                Object exposing $optimize(); presumably a cmdstanr
##                      CmdStanModel (TODO confirm).
## data, init, threads, tol_obj, algorithm, init_alpha
##                      Passed through to model$optimize().
## checkpoint_directory Directory holding checkpoints; created if missing.
## checkpoint_every     Iteration budget (`iter`) for this chunk.
##
## Returns list(opt = <fit>, checkpoint_number = <number just written>,
##              converged = <TRUE if "Convergence detected" was printed>).
resume_optimize <- function(
    opt,
    model,
    data,
    init,
    checkpoint_directory,
    checkpoint_every = 100,
    threads = 1,
    tol_obj = 1e-5,
    algorithm = "lbfgs",
    init_alpha = 0.0001
) {
  diagnostics_dir <- file.path(checkpoint_directory, "diagnostics")
  # recursive = TRUE also creates checkpoint_directory itself when needed.
  dir.create(diagnostics_dir, recursive = TRUE, showWarnings = FALSE)

  # Only checkpoint CSVs written by this function count. Both the current
  # "checkpoint-<N>..." naming and an older "checkpoint.<N>..." naming are
  # accepted, e.g. checkpoint-2-202401011200-1-abc.csv -> 2.
  checkpoint_files <- list.files(
    checkpoint_directory,
    pattern = "^checkpoint[-.][0-9]+.*\\.csv$"
  )
  checkpoint_numbers <- as.numeric(
    sub("^checkpoint[-.]([0-9]+).*$", "\\1", checkpoint_files)
  )

  seed <- 0
  if (length(checkpoint_files) > 0) {
    # Resume from the most recent checkpoint.
    checkpoint_path <- file.path(
      checkpoint_directory,
      checkpoint_files[which.max(checkpoint_numbers)]
    )
    checkpoint_lines <- readLines(checkpoint_path)
    # Reuse the recorded seed, from a line that looks like "seed = 0".
    seed_line <- grep("seed = [0-9]+", checkpoint_lines, value = TRUE)[1]
    if (is.na(seed_line)) {
      warning(
        "No 'seed = ' line found in ", checkpoint_path, "; using seed 0.",
        call. = FALSE
      )
    } else {
      seed <- as.numeric(sub(".*seed = ([0-9]+).*", "\\1", seed_line))
    }
    init <- list(reconstruct_stan_csv(checkpoint_path))
  }

  checkpoint_number <- if (length(checkpoint_numbers) > 0) {
    max(checkpoint_numbers) + 1
  } else {
    1
  }
  diagnostics_file <- file.path(
    diagnostics_dir,
    paste0(checkpoint_number, ".txt")
  )

  cat("Running optimization for checkpoint ", checkpoint_number, "\n")
  # Capture the optimizer's printed output so convergence can be detected
  # afterwards. The sink is removed even if optimize() errors, so the console
  # is never left silently redirected.
  sink(file = diagnostics_file, type = "output", split = FALSE)
  fit <- tryCatch(
    model$optimize(
      data = data,
      tol_obj = tol_obj,
      algorithm = algorithm,
      init_alpha = init_alpha,
      threads = threads,
      init = init,
      iter = checkpoint_every,
      seed = seed
    ),
    finally = sink()
  )

  fit$save_output_files(
    dir = checkpoint_directory,
    basename = paste0("checkpoint-", checkpoint_number)
  )

  converged <- any(
    grepl("Convergence detected", readLines(diagnostics_file), fixed = TRUE)
  )
  list(opt = fit, checkpoint_number = checkpoint_number, converged = converged)
}
## Repeatedly run checkpointed optimization chunks until convergence.
##
## Calls resume_optimize() in a loop; each call runs up to `checkpoint_every`
## iterations, resuming from the latest checkpoint in `checkpoint_directory`
## if one exists (so `init` is only used when no checkpoint exists yet).
## The loop stops when a chunk reports convergence, or after
## `max_checkpoints` chunks. In the latter case a warning is issued and the
## latest fit is returned. The default Inf keeps the original unbounded
## behavior.
##
## All other arguments are forwarded to resume_optimize().
## Returns the fit object from the last chunk run.
resumable_optimize <- function(
    model,
    data,
    init,
    checkpoint_directory,
    checkpoint_every = 100,
    threads = 1,
    tol_obj = 1e-5,
    algorithm = "lbfgs",
    init_alpha = 0.0001,
    max_checkpoints = Inf
) {
  stopifnot(
    is.numeric(max_checkpoints),
    length(max_checkpoints) == 1,
    max_checkpoints >= 1
  )

  opt <- list(opt = NA, checkpoint_number = NA)
  chunks_run <- 0
  repeat {
    opt <- resume_optimize(
      opt = opt,
      model = model,
      data = data,
      init = init,
      checkpoint_directory = checkpoint_directory,
      checkpoint_every = checkpoint_every,
      threads = threads,
      tol_obj = tol_obj,
      algorithm = algorithm,
      init_alpha = init_alpha
    )
    chunks_run <- chunks_run + 1

    if (isTRUE(opt$converged)) {
      cat("\nOptimization has converged\n")
      break
    }
    # Without a cap, a run that never prints "Convergence detected" (e.g. a
    # repeatedly failing line search) would loop forever.
    if (chunks_run >= max_checkpoints) {
      warning(
        "Optimization did not converge within ", max_checkpoints,
        " checkpoint runs; returning the latest fit.",
        call. = FALSE
      )
      break
    }
  }
  opt$opt
}