Using reduce_sum with the variational inference algorithm

Hello,

I would like to know the best way to specify within-chain parallelization (reduce_sum) when using the variational inference algorithm. Code to generate the data is at the end of this post.

At first I tried to specify the number of threads to use via the threads argument:

m1_threads <- 
  brm(
    formula = bf0,
    prior = prior0,
    data = data0,
    iter = 1000,
    backend = "cmdstanr",
    algorithm = 'meanfield',
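    # request within-chain parallelization (reduce_sum) with nthreads threads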
    threads = threading(threads = nthreads, grainsize = NULL, static = FALSE),
    control = list(tol_rel_obj = 0.00000001),
    refresh = 5,
    chains = nchains,
    cores = ncores
  )

But it returned an error:

Compiling Stan program...
Start sampling
Error in .fun(data = .x1, seed = .x2, init = .x3, threads_per_chain = .x4,  : 
  unused argument (threads_per_chain = .x4)

After a while I tried another, more complicated approach:

  1. I made native Stan code from the brm objects and saved it as a .stan file (make_stancode function).
  2. I made a native Stan data structure from the brm objects (make_standata function).
  3. In the saved .stan file I changed reduce_sum to reduce_sum_static.
  4. Compiled the Stan code with threading enabled via cpp_options = list(stan_threads = TRUE).
  5. Ran variational inference with mod$variational.
  6. Saved the results to a CSV file. The problem is that reading this CSV back takes hours for large N * [number of latent variables].

Here is code that illustrates all six steps:


threads = 2 # needs to be tuned

# make native Stan code from the brm objects and save it as a .stan file
my_stan_code <- make_stancode(
  formula = bf0,
  data = data0,
  prior = prior0,
  autocor = NULL,
  data2 = NULL,
  cov_ranef = NULL,
  sparse = NULL,
  sample_prior = "no",
  stanvars = NULL,
  stan_funs = NULL,
  knots = NULL,
  threads = threading(threads = threads),
  normalize = getOption("brms.normalize", TRUE),
  save_model = '/home/rstudio/user_projects/Neringos/psc_20220117.stan'
)

# make a native stan data structure from brm objects
my_stan_data <- make_standata(
  formula = bf0,
  data = data0,
  #family = gaussian(),
  prior = prior0,
  autocor = NULL,
  data2 = NULL,
  cov_ranef = NULL,
  sample_prior = "no",
  stanvars = NULL,
  threads = threading(threads = threads),
  knots = NULL
)

# need to unclass so the data works in the steps below
my_stan_data <- unclass(my_stan_data)
# need to specify the missing grainsize parameter (https://mc-stan.org/docs/2_23/stan-users-guide/reduce-sum.html)
my_stan_data$grainsize <- max(c(100, ceiling(nrow(data0) / (1 * threads))))
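# e.g. with N = 1000 and threads = 2 this gives max(100, ceiling(1000 / 2)) = 500,
# i.e. roughly one partial-sum block per thread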

# Manually change reduce_sum to reduce_sum_static in the saved Stan code (step 3)
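# (I made this edit by hand; below is a sketch of the same substitution done
# programmatically, assuming the generated code calls reduce_sum(...) directly)
stan_code <- readLines('/home/rstudio/user_projects/Neringos/psc_20220117.stan')
stan_code <- gsub("reduce_sum(", "reduce_sum_static(", stan_code, fixed = TRUE)
writeLines(stan_code, '/home/rstudio/user_projects/Neringos/psc_20220117.stan')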
mod <- cmdstan_model(stan_file = '/home/rstudio/user_projects/Neringos/psc_20220117.stan', compile = FALSE)

# compile stan code
(start_time <- Sys.time())
mod$compile(
  quiet = FALSE,
  dir = NULL,
  pedantic = TRUE,
  include_paths = NULL,
  cpp_options = list(stan_threads = TRUE), # enabling reduce_sum
  stanc_options = list(),
  force_recompile = TRUE,
  threads = TRUE
)
end_time <- Sys.time()
end_time - start_time

# do variational inference
(start_time <- Sys.time())
fit_vb <- mod$variational(
  data = my_stan_data,
  seed = NULL,
  refresh = NULL,
  init = NULL,
  save_latent_dynamics = FALSE,
  output_dir = '/home/rstudio/user_projects/Neringos/',
  output_basename = NULL,
  sig_figs = NULL,
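  # threads is accepted directly by cmdstanr's variational method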
  threads = threads,
  opencl_ids = NULL,
  algorithm = NULL,
  iter = 1000, # increase iter to increase computation time; alternatively, use a bigger sample
  grad_samples = NULL,
  elbo_samples = NULL,
  eta = NULL,
  adapt_engaged = NULL,
  adapt_iter = NULL,
  tol_rel_obj = 0.0000001,
  eval_elbo = NULL,
  output_samples = NULL
)
end_time <- Sys.time()
end_time - start_time

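# create an empty brmsfit object (empty = TRUE skips fitting) so that the
# cmdstanr results can be inserted into a brms object afterwards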
fit <- brm(formula = bf0, 
           prior = prior0,
           data = data0, 
           iter = 1000,
           backend = 'cmdstanr',
           algorithm = 'meanfield',
           tol_rel_obj = 0.00000001,
           empty = TRUE
)

# reading this csv takes hours for large N * [number of latent variables] 
(start_time <- Sys.time())
stanfit <- rstan::read_stan_csv(fit_vb$output_files())
end_time <- Sys.time()
end_time - start_time

fit$fit <- stanfit
fit <- rename_pars(fit)
fit

Is there a better way to specify the number of threads for the variational inference algorithm?

Code to generate the data:

rm(list = ls())

library(tidyverse)
library(cmdstanr)
library(rstan)
library(brms)
#tryCatch({cmdstanr::set_cmdstan_path('/home/cmdstanr/cmdstan-2.27.0')})
#tryCatch({cmdstanr::set_cmdstan_path('/home/cmdstanr/cmdstan-2.28.1')})
tryCatch({cmdstanr::set_cmdstan_path('/home/cmdstanr/cmdstan-2.28.2')})
cmdstan_version()

parallel::detectCores(all.tests = TRUE, logical = TRUE)

# ------------------------------------------------------------------------------
# Generate data
# ------------------------------------------------------------------------------

N <- 1000
nchains = 3
ncores = 3
nthreads = 8

data0 <- 
  tibble(
    x = rnorm(N, 0, 1), # x = F1
    x1 = rnorm(N, 1*x, 1),
    x2 = rnorm(N, 0.5*x, 1),
    x3 = rnorm(N, 1*x, 1),
    F1 = as.numeric(NA)
  )


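# measurement model: x1, x2, and x3 each load on a single latent factor F1,
# which enters the formulas through mi() terms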
bf0 <- bf(x1 ~ 0 + mi(F1)) +
  bf(x2 ~ 0 + mi(F1)) +
  bf(x3 ~ 0 + mi(F1)) +
  bf(F1 | mi() ~ 1) + 
  set_rescor(rescor = FALSE)

prior0 <- prior(constant(1), class = "b", resp = "x1") +
  prior(constant(1), class = "sigma", resp = "x1") +
  prior(normal(0, 10), class = "b", resp = "x2") +
  prior(constant(1), class = "sigma", resp = "x2") +
  prior(normal(0, 10), class = "b", resp = "x3") +
  prior(constant(1), class = "sigma", resp = "x3") +
  prior(normal(0, 10), class = "Intercept", resp = "F1") +
  prior(cauchy(0, 1), class = "sigma", resp = "F1")

CmdStan version: 2.28.2
brms version: 2.16.1

Whoa, that's surprising. Have you checked the Stan-CSV-reading functions in cmdstanr?
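
For example, something along these lines (an untested sketch, assuming fit_vb is the CmdStanVB object from step 5 above):

# read the same CSVs with cmdstanr's reader instead of rstan::read_stan_csv
vb_csv <- cmdstanr::read_cmdstan_csv(fit_vb$output_files())
str(vb_csv$draws) # approximate posterior draws

# or skip the CSV round-trip entirely and extract draws from the fit object
draws <- fit_vb$draws()

You would still need to get the result into a stanfit object for brms, but this would at least tell you whether read_stan_csv itself is the bottleneck.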

Absent the threads argument working as it should in brms, your procedure is indeed what I'd do, but this actually looks like a bug in brms: the error suggests brms passes a threads_per_chain argument that cmdstanr's variational method does not accept. Maybe submit a bug report here?

Hello,

Thank you for your answer. I have not checked any other Stan-CSV-reading functions yet, but that seems like a good idea. I am a new member here; there is still much to learn.

I created a separate question on this forum about the speed of the read_stan_csv function. Also, as you suggested, I submitted a bug report about the error when using threads with variational inference.