How to parallelize PyStan / Stan.jl / RStan runs that use different data?

Most Stan interfaces have built-in parallelization of identical chains, but running parallel chains with different data doesn’t seem as straightforward (maybe with the exception of CmdStan, which works a bit differently). PyStan, for instance, uses the multiprocessing module for its built-in parallelization, but that module’s map function could be used instead to parallelize runs with different data (at least with pystan<3).

Question: is there a recommended approach to run parallel chains with different data, and is there a better interface for doing that (except for CmdStan, preferably between Python, Julia, and R interfaces)?

I think multiprocessing should work in this case.

multiprocessing has some quirks with the kind of object it can parallelize: it needs to be “pickleable” (which is why I think it may not work with CmdStanPy, since it’s running an external program), and it seems to have some issues with functions imported from different modules (e.g. a helper function that would wrap the Stan run and its arguments).
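To make the “pickleable” requirement concrete: multiprocessing sends the function and its arguments to the workers with pickle, and functions are pickled by reference (module name plus function name), so lambdas and nested functions fail while importable module-level functions work. A minimal sketch of checking this up front; `fit_one` is a hypothetical stand-in for a Stan run, not a PyStan call, and `is_picklable` is a made-up helper name:

```python
import pickle


def fit_one(dataset):
    """Hypothetical stand-in for a Stan run: returns the mean of dataset["y"]."""
    return sum(dataset["y"]) / len(dataset["y"])


def is_picklable(obj):
    """Return True if obj survives pickle.dumps, which multiprocessing relies on."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


if __name__ == "__main__":
    # Plain data (dicts, lists, numbers) can be sent to worker processes.
    print("data dict:", is_picklable({"y": [1.0, 2.0, 3.0]}))
    # Module-level functions are pickled by reference (module + name).
    print("module-level function:", is_picklable(fit_one))
    # A lambda has no importable name, so it cannot be sent to a worker.
    print("lambda:", is_picklable(lambda d: d))
```

Running a check like this on the wrapper function and on each data dictionary before handing them to a pool usually shows quickly which object is the problem.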

So it’s not as straightforward as multiprocessing anything else, and I wondered if there was any case study/workflow example where that was illustrated. But if there isn’t anything of the sort, working around these issues is doable. Thanks.

For the record, when trying to use the multiprocessing map function I am getting the following error:

RuntimeError: Exception during call to services function: AssertionError('daemonic processes are not allowed to have children')

I remember getting a similar error once with PyStan 2 when I tried running multiple chains within a Stan job and multiprocess that – I changed the Stan job to run a single chain and it worked. But I am running a single chain here (num_chains=1) so I am not sure what’s wrong. I have been using different interfaces and the package/keywords changed quite a bit since last time I used PyStan, so I may be missing something trivial.

Ok, there might be things that I’m not sure how to handle. Maybe there is some specific option when you call multiprocessing (daemon=False?).

There is one way to do this and I have used this previously.

Create a Python script that does everything and finally saves the results (for example, save the InferenceData object created with ArviZ, and remember to use a unique file name for each run), then call that script multiple times with subprocess (or manually, or in a bash loop).

Then when everything is done, handle data the way you need to.

But this of course assumes that your application is static and doesn’t need live results.
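A minimal sketch of that script-plus-subprocess pattern, for reference. The worker script here only computes a mean and writes JSON; in a real run it would build the Stan model, sample, and save the fit instead. The file names and the `run_jobs` helper are made up for illustration:

```python
import json
import subprocess
import sys
import tempfile
from pathlib import Path

# Hypothetical worker script. In practice this would run Stan and save the fit
# (e.g. an ArviZ InferenceData file) rather than compute a mean.
WORKER_SOURCE = """
import json, sys
data_path, out_path = sys.argv[1], sys.argv[2]
with open(data_path) as f:
    data = json.load(f)
result = {"mean_y": sum(data["y"]) / len(data["y"])}
with open(out_path, "w") as f:
    json.dump(result, f)
"""


def run_jobs(datasets, workdir):
    """Launch one Python process per dataset and return the parsed results in order."""
    workdir = Path(workdir)
    script = workdir / "fit_worker.py"
    script.write_text(WORKER_SOURCE)

    jobs = []
    for i, data in enumerate(datasets):
        data_path = workdir / f"data_{i}.json"
        out_path = workdir / f"result_{i}.json"  # unique output name per run
        data_path.write_text(json.dumps(data))
        # Popen returns immediately, so all workers run concurrently.
        proc = subprocess.Popen(
            [sys.executable, str(script), str(data_path), str(out_path)]
        )
        jobs.append((proc, out_path))

    results = []
    for proc, out_path in jobs:
        if proc.wait() != 0:
            raise RuntimeError(f"worker for {out_path.name} failed")
        results.append(json.loads(out_path.read_text()))
    return results


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        print(run_jobs([{"y": [1.0, 2.0, 3.0]}, {"y": [10.0, 20.0]}], tmp))
```

Because each worker is a fresh interpreter, nothing needs to be pickled and the daemonic-process restriction does not apply; the cost is that data and results have to go through files.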


Thanks. I was trying to get everything to run within a Python shell/run; I know it is (or at least was) possible because I did it before as well, but it was a single script with all functions, and now I’m organizing everything in a package-like structure so it makes sense to others as well (also the reason why I wanted it within Python, to avoid having to compile externally like with CmdStan).

Multiprocessing seems to have some quirks, and problems appear when you change things slightly (of course it may be because I don’t know exactly how it’s doing it under the hood).

Thanks for the help and suggestions, it will also help with the documentation.

Here’s my current, partially documented solution for rstan/cmdstanr. It takes a named list of fit arguments shared by all runs and a list of named lists with the arguments specific to each run. So to run one model with different data you would have args_shared = list(model = my_model), args_per_fit = list(list(data = data1), list(data = data2),...). Optionally, you can cache the fits and apply a summarizing function (also in parallel), which is useful if all the fits would not fit into memory.

#' @param summarise_fun a function to process each fit. This function is run in parallel.
#'   Note that this function must be runnable in new R sessions. You may
#'   use the `R_session_init_expr` or `summarise_fun_dependencies` parameters to ensure it is loaded.
#' @param summarise_fun_dependencies a list of package names that need to be loaded for
#'   `summarise_fun` to run. IMPORTANT: when developing packages, you need to install
#'   the latest version, `devtools::load_all()` won't be enough
#'    (the packages are loaded from the default library)
#' @param cache_dir if not NULL, fits will be cached in this directory
#' @return A list of length `length(args_per_fit)` containing the fits, or the result of
#'   applying `summarise_fun` to each fit when `summarise_fun` is given.
sampling_parallel <- function(args_shared, args_per_fit,
                            total_cores = getOption("mc.cores", 1),
                            cores_per_fit = NULL,
                            convert_cmdstan_fits_to_rstan = FALSE,
                            fits_in_parallel = NULL,
                            summarise_fun = NULL,
                            summarise_fun_dependencies = c(),
                            cache_fits = FALSE,
                            cache_summaries = FALSE,
                            cache_dir = NULL,
                            R_session_init_expr = NULL) {

  if(!is.list(args_shared)) {
    stop("args_shared must be a list")
  }
  if(!is.list(args_per_fit) || length(args_per_fit) <= 0) {
    stop("args_per_fit must be a non-empty list")
  }

  if(!is.null(cache_dir) && !dir.exists(cache_dir)) {
    stop(paste0("Cache dir '", cache_dir,"'  does not exist"))
  }

  if((cache_fits || cache_summaries) && is.null(cache_dir)){
    stop("Caching turned on but cache_dir not given")
  }

  if(cache_summaries && is.null(summarise_fun)) {
    stop("cache_summaries can only be used if summarise_fun is not null")
  }

  total_cores <- as.integer(total_cores)
  if(length(total_cores) > 1 || total_cores < 1 || is.na(total_cores)) {
    stop("Cores must be a single integer greater than 0")
  }

  n_fits <- length(args_per_fit)

  if("cores" %in% names(args_shared) || "num_cores" %in% names(args_shared)) {
    stop("args_shared must not specify cores or num_cores")
  }

  uses_rstan <- FALSE
  uses_cmdstan <- FALSE
  model_in_shared_args <- FALSE
  data_in_shared_args <- "data" %in% names(args_shared)

  if("model" %in% names(args_shared)) {
    if(inherits(args_shared$model, "stanmodel")) {
      uses_rstan <- TRUE
    } else if(inherits(args_shared$model, "CmdStanModel")) {
      uses_cmdstan <- TRUE
    } else {
      stop("Model in shared args is not of class 'stanmodel' or 'CmdStanModel'")
    }
    model_in_shared_args <- TRUE
  }

  for(i in 1:n_fits) {
    if(!is.list(args_per_fit[[i]])) {
      stop("All elements of args_per_fit have to be lists")
    }

    if(length(intersect(names(args_shared), names(args_per_fit[[i]]))) > 0) {
      stop(paste0("No parameters provided in args_per_fit can be given in args_shared.\n",
                  "Found intersection at index ", i, "."))
    }

    if("model" %in% names(args_per_fit[[i]])) {
      if(inherits(args_per_fit[[i]]$model, "stanmodel")) {
        uses_rstan <- TRUE
      } else if(inherits(args_per_fit[[i]]$model, "CmdStanModel")) {
        uses_cmdstan <- TRUE
      } else {
        stop(paste0("Model for fit id ", i," is not of class 'stanmodel' or 'CmdStanModel'"))
      }
    } else if(!model_in_shared_args) {
      stop(paste0("No model argument in args_shared and fit id ", i, " does not provide model"))
    }

    if(!data_in_shared_args && !("data" %in% names(args_per_fit[[i]]))) {
      stop(paste0("No data argument in args_shared and fit id ", i, " does not provide data"))
    }


    if("cores" %in% names(args_per_fit[[i]]) || "num_cores" %in% names(args_per_fit[[i]])) {
      stop(paste0("args_per_fit[[", i, "]] must not specify cores or num_cores"))
    }
  }

  if(is.null(fits_in_parallel)) {
    if(2 * n_fits <= total_cores) {
      fits_in_parallel <- n_fits
    } else {
      fits_in_parallel <- min(c(total_cores, n_fits))
    }
  }

  if(is.null(cores_per_fit)) {
    if(2 * n_fits <= total_cores) {
      cores_per_fit <- floor(total_cores / n_fits)
    } else {
      cores_per_fit <- 1
    }
  }

  fit_fun <- function(args, args_shared, summarise_fun,
                      convert_cmdstan_fits_to_rstan,
                      cores_per_fit,
                      cache_dir, cache_fits, cache_summaries,
                      cmdstan_fit_dir) {
    all_args <- c(args_shared, args)
    all_args$cores <- cores_per_fit

    model <- all_args$model
    all_args$model <- NULL

    summarise_fun_args <- all_args$summarise_fun_args
    all_args$summarise_fun_args <- NULL



    if(inherits(model, "stanmodel")) {
      model_code <- model@model_code
      is_rstan <- TRUE
    } else if(inherits(model, "CmdStanModel")) {
      model_code <- model$code()
      is_rstan <- FALSE
    } else {
      stop("Invalid model")
    }




    if(cache_fits || cache_summaries) {
      data <- all_args$data
      code_hash <- rlang::hash(model_code)
      data_hash <- rlang::hash(data)
    }

    summary_cached <-  FALSE
    if(!is.null(summarise_fun) && cache_summaries) {
      summary_cache_file <- paste0(cache_dir, "/summary_", code_hash, "_", data_hash, ".rds")
      if(file.exists(summary_cache_file)) {
        result <- readRDS(summary_cache_file)
        summary_cached <- TRUE
      }
    }

    if(!summary_cached) {
      fit_cached <- FALSE
      if(cache_fits) {
        fit_cache_file <- paste0(cache_dir, "/fit_", code_hash, "_", data_hash, ".rds")
        if(file.exists(fit_cache_file)) {
          fit_from_file <- readRDS(fit_cache_file)
          if((is_rstan && inherits(fit_from_file, "stanfit"))
             || (!is_rstan && !convert_cmdstan_fits_to_rstan && inherits(fit_from_file, "CmdStanMCMC"))
             || (!is_rstan && convert_cmdstan_fits_to_rstan && inherits(fit_from_file, "stanfit"))
             ) {
            fit <- fit_from_file
            fit_cached <- TRUE
          }
        }
      }


      if(!fit_cached) {
        if(inherits(model,"stanmodel")) {
          all_args_ordered <- c(list(model), all_args)
          fit <- do.call(rstan::sampling, args = all_args_ordered)
          # fit_cache_file is only defined when cache_fits is TRUE
          if(!is.null(cache_dir) && cache_fits) {
            saveRDS(fit, fit_cache_file)
          }
        } else {
          translated_args <- list()
          for(old in names(all_args)) {
            # Argument names follow current cmdstanr (`chains`, `max_treedepth`);
            # older releases used `num_chains` and `max_depth`
            if(old == "chains") {
              translated_args$chains = all_args$chains
            } else if(old == "cores") {
              translated_args$parallel_chains = all_args$cores
            } else if(old == "control") {
              if(!is.null(all_args$control$adapt_delta)) {
                translated_args$adapt_delta = all_args$control$adapt_delta
              }
              if(!is.null(all_args$control$max_treedepth)) {
                translated_args$max_treedepth = all_args$control$max_treedepth
              }
            } else if(old == "iter") {
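              # rstan's iter includes warmup (default floor(iter / 2)); keep both counts whole
              translated_args$iter_warmup = floor(all_args$iter / 2)
              translated_args$iter_sampling = all_args$iter - floor(all_args$iter / 2)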
            } else {
              translated_args[[old]] = all_args[[old]]
            }
          }
          fit <- do.call(model$sample, args = translated_args)
          if(convert_cmdstan_fits_to_rstan) {
            fit <- rstan::read_stan_csv(fit$output_files())
            if(!is.null(cache_dir) && cache_fits) {
              saveRDS(fit, fit_cache_file)
            }
          } else {
            fit$save_output_files(cmdstan_fit_dir)
            if(!is.null(cache_dir) && cache_fits) {
              fit$save_object(fit_cache_file)
            }
          }
        }
      } # End - if(!fit_cached)

      if(!is.null(summarise_fun)) {
        result <- do.call(summarise_fun, args = c(list(fit), summarise_fun_args))
        if(!is.null(cache_dir) && cache_summaries) {
          saveRDS(result, summary_cache_file)
        }
      } else {
        result <- fit
      }

    } # End - if(!summary_cached)


    result
  }

  dependencies <- c()
  if(uses_rstan) {
    dependencies <- c(dependencies, "rstan", "Rcpp")
  }
  if(uses_cmdstan) {
    dependencies <- c(dependencies, "cmdstanr")
  }
  dependencies <- c(dependencies, summarise_fun_dependencies)

  lapply_args <- list(args_shared = args_shared,
                      summarise_fun = summarise_fun,
                      convert_cmdstan_fits_to_rstan = convert_cmdstan_fits_to_rstan,
                      cores_per_fit = cores_per_fit,
                      cache_dir = cache_dir,
                      cmdstan_fit_dir = tempdir(),
                      cache_fits = cache_fits,
                      cache_summaries = cache_summaries)

  if(fits_in_parallel == 1) {
    for(dep in dependencies) {
      suppressPackageStartupMessages(require(dep, quietly = TRUE, character.only = TRUE))
    }
    eval(R_session_init_expr, envir = environment())
    results <- do.call(lapply, args = c(
      list(X = args_per_fit, FUN = fit_fun),
      lapply_args))
  } else {
    cl <- parallel::makeCluster(fits_in_parallel, useXDR = FALSE)
    on.exit(parallel::stopCluster(cl))

    .paths <- unique(c(.libPaths(), sapply(dependencies, FUN = function(d) {
      dirname(system.file(package = d))
    })))
    .paths <- .paths[.paths != ""]
    parallel::clusterExport(cl, varlist = c(".paths","dependencies"), envir = environment())
    parallel::clusterEvalQ(cl, expr = .libPaths(.paths))
    parallel::clusterEvalQ(cl, expr =
                             for(dep in dependencies) {
                               suppressPackageStartupMessages(require(dep, quietly = TRUE, character.only = TRUE))
                             }
    )

    #    parallel::clusterExport(cl, varlist = "args_shared", envir = environment())

    parallel::clusterExport(cl, varlist =
                              c("R_session_init_expr"),
                            envir = environment())
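    # clusterEvalQ() quotes its argument, so the stored expression has to be eval()-ed explicitly
    parallel::clusterEvalQ(cl, expr = eval(R_session_init_expr))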

    results <- do.call(parallel::parLapplyLB,
                       args = c(list(cl = cl,
                                     X = args_per_fit,
                                     fun = fit_fun,
                                     chunk.size = 1),
                                lapply_args))

  }

  results
}
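For the Python readers in the thread, the two pieces of bookkeeping that the function above does before fitting (rejecting arguments given both as shared and per fit, and splitting the core budget between fits) could look roughly like this. This is only a Python paraphrase of the R logic, and the names `split_cores` and `merge_args` are made up for illustration:

```python
def split_cores(n_fits, total_cores):
    """Mirror the R defaults: return (fits_in_parallel, cores_per_fit)."""
    if n_fits < 1 or total_cores < 1:
        raise ValueError("n_fits and total_cores must be positive")
    if 2 * n_fits <= total_cores:
        # Enough cores to give every fit at least two: run all fits at once.
        return n_fits, total_cores // n_fits
    # Otherwise run one core per fit, with as many fits at once as cores allow.
    return min(total_cores, n_fits), 1


def merge_args(args_shared, args_per_fit):
    """Combine shared and per-fit keyword arguments, rejecting overlapping names."""
    merged = []
    for i, args in enumerate(args_per_fit):
        overlap = set(args_shared) & set(args)
        if overlap:
            raise ValueError(f"fit {i} repeats shared arguments: {sorted(overlap)}")
        merged.append({**args_shared, **args})
    return merged


if __name__ == "__main__":
    # "my_model" is a placeholder string; a real call would pass a compiled model.
    shared = {"model": "my_model", "chains": 2}
    per_fit = [{"data": {"y": [1, 2, 3]}}, {"data": {"y": [4, 5, 6]}}]
    print(merge_args(shared, per_fit))
    print(split_cores(n_fits=len(per_fit), total_cores=8))
```

Each merged dictionary would then be passed to whatever function runs a single fit, with the returned `cores_per_fit` used for the within-fit chain parallelism.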

Thanks for the detailed example. I’ll need some time to digest it, especially given my aversion to curly brackets, but if it makes parallel mapping less cumbersome to implement, that’s one point for R against Python (CmdStan still seems to be the undefeated champion in that matter, though). Thanks again.

Maybe this could work

Thanks. I will check it out. Since I am running a single chain, I don’t understand why it would spawn other processes, but investigating it may take some time and it’s not at the top of the priority list.
I’ll make sure I reply here if I am able to solve it. Thanks.