Using Stan as part of a Gibbs sampling algorithm

I want to use Stan to sample from a conditional distribution within a larger model. The sampling idea is along the lines of the following Gibbs-type algorithm:

  • sample a load of discrete parameters, L, conditional on non-discrete parameters, theta; this step can be done using independent sampling since the conditional can be efficiently calculated and sampled from
  • sample theta | L using either Stan’s NUTS or HMC algorithms

(FYI, there’s no way the discrete parameters can be integrated out.)
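
In rough R pseudocode, the loop I have in mind looks like this (both update functions are placeholders, not real API):

# Hypothetical sketch of the intended scheme; sample_L_given_theta() and
# nuts_update_theta() are placeholder names, not real functions
for (i in 1:n_iterations) {
  L     <- sample_L_given_theta(theta)   # exact draw from p(L | theta)
  theta <- nuts_update_theta(theta, L)   # a few NUTS/HMC steps via Stan
}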

My question is how to implement the above, particularly as I am developing a package for it. One key issue is that Stan tunes a number of sampler hyperparameters (e.g. the step size) during warmup, and I would like to do warmup once and then reuse the same hyperparameters in all subsequent calls to Stan (I doubt that the optimal values would vary considerably across the various values of L).

I don’t think that rstan would allow this (I wasn’t sure that you could pass the necessary hyperparameters to sampling – very happy to be wrong about this!). That made me think of CmdStanR, but I’ve not used it, and I wasn’t sure a) whether it would permit me to do the above and b) how it would affect users wanting to install my package.

Anyone got any suggestions on this?

In case it’s useful to others, I have worked out a solution using CmdStanR.

I will exemplify this with a simple model where X\sim \text{binomial}(N, \theta) and both N and \theta are unknown. Of course, for this very simple model, I can fit it in Stan directly by marginalising N out of the joint distribution. But here, I want to show how it’s possible to use Stan as part of a Metropolis-within-Gibbs sampling algorithm when the discrete parameters cannot easily be marginalised out.
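
For reference, here is a sketch of what that direct marginalisation might look like in Stan, assuming the \text{discrete-uniform}(5, 12) prior on N introduced just below:

data {
  int<lower=1> n;
  array[n] int<lower=0> X;
}

parameters {
  real<lower=0, upper=1> theta;
}

model {
  // sum over the 8 support points N = 5, ..., 12; the uniform prior on N
  // contributes only a constant, so it is dropped
  vector[8] lp;
  for (k in 1:8) {
    lp[k] = binomial_lpmf(X | k + 4, theta);
  }
  target += log_sum_exp(lp);
}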

I suppose that N\sim \text{discrete-uniform}(5, 12) and \theta\sim \text{uniform}(0,1). The Stan code to sample from \theta | N is:

data {
  int<lower=1> N;           // current value of the discrete parameter
  int<lower=1> n;           // number of observations
  array[n] int<lower=0> X;  // binomial counts
}

parameters {
  real<lower=0, upper=1> theta;
}

model {
  // theta ~ uniform(0, 1) is implicit in the declared bounds
  X ~ binomial(N, theta);
}

To sample N | \theta, X, I calculate the RHS of p(N | \theta, X) \propto \prod_{j=1}^n \text{binomial}(X_j | N, \theta) for each value of N in the prior’s support. I then normalise these by the sum of all the unnormalised probabilities, which allows me to sample N directly. The R functions that accomplish this (on the log scale) are:

library(matrixStats)

# unnormalised log-probability of a single N, given theta and the data X
calculate_log_p <- function(N, theta, X) {
  sum(dbinom(X, N, theta, log = TRUE))
}

# evaluate the unnormalised log-probability at each candidate N
calculate_all_log_p <- function(Ns, theta, X) {
  purrr::map_dbl(Ns, ~calculate_log_p(., theta, X))
}

# normalise on the log scale and draw a single N
sample_N <- function(Ns, theta, X) {
  qs <- calculate_all_log_p(Ns, theta, X)
  probs <- exp(qs - logSumExp(qs))
  sample(Ns, 1, prob = probs)
}
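
As a quick sanity check on sample_N, with simulated data (not from the attached script):

set.seed(1)
X  <- rbinom(20, size = 10, prob = 0.4)  # simulate with true N = 10, theta = 0.4
Ns <- 5:12                               # support of the discrete-uniform prior
sample_N(Ns, theta = 0.4, X)             # should mostly return values near 10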

I then string together the two conditional updates into a Gibbs-type algorithm:

library(tibble)

gibbs_sampler <- function(
    n_iterations,
    X, init_N, init_theta, Ns,
    init_step_size, n_nuts_per_step, model) {

  N <- init_N
  theta <- init_theta
  theta_draws <- numeric(n_iterations)
  N_draws <- numeric(n_iterations)

  for(i in 1:n_iterations) {

    print(i)  # simple progress indicator

    data_stan <- list(
      N = N,
      n = length(X),
      X = X
    )

    # initialise theta at its current value so the chain continues from there
    init_fn <- function() {
      list(theta = theta)
    }

    # short, warmup-free NUTS run at the fixed step size found earlier
    fit <- model$sample(data = data_stan,
                        init = init_fn,
                        iter_warmup = 0,
                        adapt_engaged = FALSE,
                        iter_sampling = n_nuts_per_step,
                        refresh = 0,
                        show_messages = FALSE,
                        show_exceptions = FALSE,
                        step_size = init_step_size,
                        chains = 1)

    # keep only the last of the n_nuts_per_step draws
    theta <- fit$draws("theta")[n_nuts_per_step, 1, 1][[1]]

    # exact conditional update of the discrete parameter
    N <- sample_N(Ns, theta, X)

    theta_draws[i] <- theta
    N_draws[i] <- N
  }

  tibble(theta = theta_draws, N = N_draws)
}
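
Calling it looks something like this (mod is the compiled model and init_step_size comes from the warm-up run shown next; the data here are simulated):

set.seed(1)
X <- rbinom(20, size = 10, prob = 0.4)

draws <- gibbs_sampler(
  n_iterations = 1000,
  X = X, init_N = 8, init_theta = 0.5, Ns = 5:12,
  init_step_size = init_step_size,  # from the warm-up run below
  n_nuts_per_step = 5,
  model = mod                       # compiled via cmdstan_model() below
)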

The trick to getting Stan to work well for its part is to first do an initial run to find an appropriate step size; I use a high adapt_delta value to minimise the chance that this step size is too large and leads to divergent iterations when deployed in the Gibbs algorithm (where N varies from step to step):

library(cmdstanr)

mod <- cmdstan_model("coin_flip.stan")  # the Stan program above

# data_stan is as in the Gibbs loop, built with the initial value of N
fit_init <- mod$sample(data = data_stan,
                       iter_warmup = 400,
                       adapt_delta = 0.99,
                       iter_sampling = 1,
                       refresh = 0,
                       show_messages = FALSE,
                       show_exceptions = FALSE,
                       chains = 1)

init_step_size <- fit_init$metadata()$step_size_adaptation
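
This fixes only the step size. For a model with more than one continuous parameter, you would probably also want to reuse the adapted inverse metric (mass matrix); as far as I can tell, cmdstanr lets you extract it from the fit and pass it back into $sample():

# With the default diagonal metric, extract the adapted diagonal as a vector...
init_inv_metric <- fit_init$inv_metric(matrix = FALSE)[[1]]
# ...then pass inv_metric = init_inv_metric alongside step_size in each
# $sample() call inside the Gibbs loop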

This seems to work fine and returns a reasonable posterior distribution for \theta and N. The attached R code runs through the whole thing.
s_binomial_unknown_n.R (2.2 KB)
coin_flip.stan (131 Bytes)
