Hello,
I would like to know the best way to request within-chain parallelization (`reduce_sum`) when using the variational inference algorithm. Code to generate the data is at the end of this issue.
At first I tried to specify the number of threads via the `threads` argument of `brm()`:
# First attempt: request within-chain threading directly through brm() with
# the meanfield variational algorithm.
# With brms 2.16.1 this fails (see the error below). The threads setting is
# forwarded to the backend call as `threads_per_chain`, which that call rejects
# as an unused argument.
# NOTE(review): check whether a newer brms release supports `threads`
# together with algorithm = "meanfield".
m1_threads <- brm(
  formula = bf0,
  prior = prior0,
  data = data0,
  iter = 1000,
  backend = "cmdstanr",
  algorithm = "meanfield",
  threads = threading(threads = nthreads, grainsize = NULL, static = FALSE),
  # Variational settings are forwarded through `...` to the backend.
  # `control` holds sampler settings, so tol_rel_obj must not go there.
  tol_rel_obj = 0.00000001,
  refresh = 5,
  chains = nchains,
  cores = ncores
)
But it returned an error:
Compiling Stan program...
Start sampling
Error in .fun(data = .x1, seed = .x2, init = .x3, threads_per_chain = .x4, :
unused argument (threads_per_chain = .x4)
I then tried another, more complicated approach:
1. Generated the native Stan code from the brms model with `make_stancode()` and saved it to a `.stan` file.
2. Generated the native Stan data from the brms model with `make_standata()`.
3. In the saved `.stan` file, changed `reduce_sum` to `reduce_sum_static`.
4. Compiled the Stan code with threading (needed for `reduce_sum`) enabled via `cpp_options = list(stan_threads = TRUE)`.
5. Ran variational inference with `mod$variational()`.
6. Saved the results to a CSV file.

The problem is that reading this CSV file back takes hours for large N × [number of latent variables].
Here is code that illustrates all six steps:
# Number of threads per chain for the model's reduce_sum sections (needs to be
# tuned).
threads <- 2

# One path for the generated Stan program, so the file written by
# make_stancode() is the same file that is edited and compiled below.
stan_file <- "/home/rstudio/user_projects/Neringos/psc_20220117.stan"

# Generate native Stan code from the brms model and save it to `stan_file`.
my_stan_code <- make_stancode(
  formula = bf0,
  data = data0,
  prior = prior0,
  autocor = NULL,
  data2 = NULL,
  cov_ranef = NULL,
  sparse = NULL,
  sample_prior = "no",
  stanvars = NULL,
  stan_funs = NULL,
  knots = NULL,
  threads = threading(threads = threads),
  normalize = getOption("brms.normalize", TRUE),
  save_model = stan_file
)

# Generate native Stan data from the brms model.
my_stan_data <- make_standata(
  formula = bf0,
  data = data0,
  # family = gaussian(),
  prior = prior0,
  autocor = NULL,
  data2 = NULL,
  cov_ranef = NULL,
  sample_prior = "no",
  stanvars = NULL,
  threads = threading(threads = threads),
  knots = NULL
)
# Drop the brms class so cmdstanr receives a plain list.
my_stan_data <- unclass(my_stan_data)
# Supply `grainsize` explicitly: at least 100, otherwise N / threads.
# Stan declares it as an int, so pass an integer.
# (https://mc-stan.org/docs/2_23/stan-users-guide/reduce-sum.html)
my_stan_data$grainsize <- as.integer(max(100, ceiling(nrow(data0) / threads)))

# Switch to the static (non-adaptive) reduce_sum in code instead of editing
# the .stan file by hand. With fixed = TRUE this is idempotent, because
# "reduce_sum_static(" does not contain "reduce_sum(".
# NOTE(review): threading(static = TRUE) may make brms emit
# reduce_sum_static itself; confirm for the installed brms version.
stan_lines <- readLines(stan_file)
stan_lines <- gsub("reduce_sum(", "reduce_sum_static(", stan_lines, fixed = TRUE)
writeLines(stan_lines, stan_file)

mod <- cmdstan_model(stan_file = stan_file, compile = FALSE)

# Compile with threading support.
# cpp_options = list(stan_threads = TRUE) enables reduce_sum; the deprecated
# `threads = TRUE` compile argument is redundant with it and is not passed.
(start_time <- Sys.time())
mod$compile(
  quiet = FALSE,
  dir = NULL,
  pedantic = TRUE,
  include_paths = NULL,
  cpp_options = list(stan_threads = TRUE),
  stanc_options = list(),
  force_recompile = TRUE
)
end_time <- Sys.time()
end_time - start_time
# Run variational inference on the compiled model.
# `threads` is the number of threads cmdstanr uses for the model's
# parallelized (reduce_sum) sections.
# Raise `iter`, or use a bigger sample, to lengthen the computation.
start_time <- Sys.time()
print(start_time)
fit_vb <- mod$variational(
  # inputs and parallelism
  data = my_stan_data,
  threads = threads,
  opencl_ids = NULL,
  seed = NULL,
  init = NULL,
  # algorithm settings
  algorithm = NULL,
  iter = 1000,
  grad_samples = NULL,
  elbo_samples = NULL,
  eta = NULL,
  adapt_engaged = NULL,
  adapt_iter = NULL,
  tol_rel_obj = 0.0000001,
  eval_elbo = NULL,
  output_samples = NULL,
  # progress and output files
  refresh = NULL,
  save_latent_dynamics = FALSE,
  output_dir = "/home/rstudio/user_projects/Neringos/",
  output_basename = NULL,
  sig_figs = NULL
)
end_time <- Sys.time()
print(end_time - start_time)
# Build an unfitted brmsfit shell (empty = TRUE) for the same model.
# The variational draws are attached to it below.
fit <- brm(
  formula = bf0,
  data = data0,
  prior = prior0,
  backend = "cmdstanr",
  algorithm = "meanfield",
  iter = 1000,
  tol_rel_obj = 0.00000001,
  empty = TRUE
)

# Read CmdStan's CSV output into a stanfit object. This step takes hours for
# large N * [number of latent variables].
# NOTE(review): fit_vb$draws() uses cmdstanr's own reader and may be faster
# if a stanfit object is not strictly required.
start_time <- Sys.time()
print(start_time)
stanfit <- rstan::read_stan_csv(fit_vb$output_files())
end_time <- Sys.time()
print(end_time - start_time)

# Put the draws into the shell, rename parameters with brms' rename_pars(),
# and show the result.
fit$fit <- stanfit
fit <- rename_pars(fit)
print(fit)
Is there a better way to specify the number of threads for the variational inference algorithm?
Code to generate the data:
# Setup and simulated data for the reproducible example.
# The workspace is not cleared with rm(list = ls()). That only removes
# objects, not loaded packages or options, and it destroys the caller's state.
library(tidyverse)
library(cmdstanr)
library(rstan)
library(brms)

# Point cmdstanr at the CmdStan installation used for this report.
# tryCatch() without handlers was a no-op wrapper, so it is dropped.
# Earlier installations that were tried:
# cmdstanr::set_cmdstan_path("/home/cmdstanr/cmdstan-2.27.0")
# cmdstanr::set_cmdstan_path("/home/cmdstanr/cmdstan-2.28.1")
cmdstanr::set_cmdstan_path("/home/cmdstanr/cmdstan-2.28.2")
cmdstan_version()
parallel::detectCores(all.tests = TRUE, logical = TRUE)

# Generate data ----

# Fix the RNG so the simulated data (and the issue) are reproducible.
set.seed(20220117)

N <- 1000
nchains <- 3
ncores <- 3
nthreads <- 8

# x is the simulated factor (the role F1 plays in the model).
# x1..x3 are noisy indicators of x.
# F1 is entirely missing and is modelled through mi() below.
data0 <- tibble(
  x = rnorm(N, 0, 1),
  x1 = rnorm(N, 1 * x, 1),
  x2 = rnorm(N, 0.5 * x, 1),
  x3 = rnorm(N, 1 * x, 1),
  F1 = NA_real_
)

# Each indicator loads on the missing variable F1, which has its own
# intercept-only model. Residual correlations are switched off.
bf0 <- bf(x1 ~ 0 + mi(F1)) +
  bf(x2 ~ 0 + mi(F1)) +
  bf(x3 ~ 0 + mi(F1)) +
  bf(F1 | mi() ~ 1) +
  set_rescor(rescor = FALSE)

# constant(1) fixes the x1 loading and all indicator sigmas at 1
# (presumably to identify the scale of F1).
# The remaining parameters get weakly informative priors.
prior0 <- prior(constant(1), class = "b", resp = "x1") +
  prior(constant(1), class = "sigma", resp = "x1") +
  prior(normal(0, 10), class = "b", resp = "x2") +
  prior(constant(1), class = "sigma", resp = "x2") +
  prior(normal(0, 10), class = "b", resp = "x3") +
  prior(constant(1), class = "sigma", resp = "x3") +
  prior(normal(0, 10), class = "Intercept", resp = "F1") +
  prior(cauchy(0, 1), class = "sigma", resp = "F1")
CmdStan version: 2.28.2
brms version: 2.16.1