Estimation, not marginalisation, of discrete parameters: a solution?

Received wisdom seems to be that one cannot estimate discrete parameters in Stan. If they are latent, one should marginalise over them oneself; but if they are not latent, i.e. if their value is the object of interest, I hadn’t seen a solution (disclaimer: I’m fairly new to Stan) until my brilliant colleague Rob Hinch came up with the following one. The particular problem is estimating a binary-valued vector, each element of which is either ‘signal’ or ‘noise’, from a vector of continuous values, each of which was generated by either a signal process or a noise process; which one is unknown. The trick is to calculate manually, for each element of the vector, P(signal | data, specific values of the model’s continuous parameters). Stan samples the continuous parameters in proportion to their posterior and returns this quantity for each sample; the mean of the resulting distribution is the posterior probability for the binary variable, P(signal | data).
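In symbols (our notation, not the post’s: z_i = 1 if observation i is signal, θ the continuous parameters, λ the signal fraction as part of θ, f_signal and f_noise the two densities, and θ^(s) for s = 1 to S the posterior draws), the argument runs as follows. Because the observations are independent given θ, the conditional probability depends on the data only through y_i:

```latex
\begin{aligned}
P(z_i = 1 \mid y, \theta) &= \frac{\lambda\, f_{\text{signal}}(y_i \mid \theta)}
  {\lambda\, f_{\text{signal}}(y_i \mid \theta) + (1 - \lambda)\, f_{\text{noise}}(y_i \mid \theta)}, \\
P(z_i = 1 \mid y) &= \int P(z_i = 1 \mid y, \theta)\, p(\theta \mid y)\, d\theta
  \approx \frac{1}{S} \sum_{s=1}^{S} P\big(z_i = 1 \mid y, \theta^{(s)}\big),
  \qquad \theta^{(s)} \sim p(\theta \mid y).
\end{aligned}
```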

Questions: (a) has this approach already been documented somewhere? If so, apologies; I missed it. (b) Does it seem correct?
Comment: if (a) = no and (b) = yes, enjoy! Rob’s code is below. I’ve adapted it for pedagogy and don’t think I have introduced errors.

(Edit: this post by Bob Carpenter explains what’s going on here: “Rao-Blackwellization and discrete parameters in Stan”, on the blog Statistical Modeling, Causal Inference, and Social Science.)
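Why average conditional probabilities rather than sample a 0/1 indicator at each draw? Both estimate P(z = 1 | y), but by the law of total variance each Rao-Blackwellized term has variance Var(p), while each sampled indicator has variance E[p(1 - p)] + Var(p). The base-R sketch below is illustrative only: it uses made-up conditional probabilities (uniform between 0.2 and 0.8, an assumption) in place of real MCMC output, and independent draws rather than a correlated chain.

```r
set.seed(1)
S <- 10000

# Stand-in for P(z = 1 | y, theta^(s)) across S posterior draws (assumed values)
p_given_draw <- runif(S, 0.2, 0.8)

# Estimator 1: sample a discrete indicator at each draw, then average
z_draw <- rbinom(S, size = 1, prob = p_given_draw)
estimate_sampled <- mean(z_draw)

# Estimator 2 (Rao-Blackwellized): average the conditional probabilities
estimate_rb <- mean(p_given_draw)

# Per-draw variances: the Rao-Blackwellized terms vary much less
var_sampled <- var(z_draw)
var_rb <- var(p_given_draw)

c(estimate_sampled = estimate_sampled, estimate_rb = estimate_rb,
  var_sampled = var_sampled, var_rb = var_rb)
```

Both estimates target the same number; the second simply gets there with less Monte Carlo noise for the same number of draws.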

The Stan code:

// Authors: Rob Hinch and Chris Wymant
// Acknowledgment: written while funded by a Li Ka Shing Foundation grant 
// awarded to Christophe Fraser.
// We have a set of observations of a value y.
// Each observation was generated by either a 'signal' process or a 'noise' 
// process; which of the two it was is unknown.
// The distributions of y for the two processes differ.
// We want to estimate the parameters of the two processes and their relative
// frequencies, i.e. the population-level parameters, which is bread-and-butter
// for mixture models. However, we also want to estimate whether each individual
// observation was signal or noise, i.e. we want the posterior for a
// binary-valued vector, which Stan does not natively support. We get around 
// this by asking Stan to explore the parameter space only of the
// population-level parameters, and manually calculating for each observation
// Prob(signal | data, population-level params). Averaging this quantity over
// draws of the population-level params from their joint posterior gives the
// posterior Prob(signal | data), as desired.
//
// We enforce that the noise distribution has a lower meanlog parameter (mu)
// than the signal distribution; otherwise the problem is symmetric under an
// exchange of what we call signal and what we call noise, and there is a
// two-fold redundancy (label switching) in all or part of parameter space that
// is likely to mess up inference.
//
// The specific distributions used for signal and noise are lognormal, with
// meanlog parameter mu and sdlog parameter sigma.

data {
  
  // First, actual data
  
  int<lower = 0> num_observations;
  array[num_observations] real<lower = 0> y;
  
  // Second, things that are not actually data but should be kept fixed over one
  // whole round of inference
  
  // Upper and lower bounds for parameter priors (assumed uniformly distributed)
  real prior_mu_signal_min;
  real<lower = prior_mu_signal_min> prior_mu_signal_max;
  real prior_mu_noise_min;
  real<lower = prior_mu_noise_min> prior_mu_noise_max;
  real<lower = 0> prior_sigma_signal_min;
  real<lower = prior_sigma_signal_min> prior_sigma_signal_max;
  real<lower = 0> prior_sigma_noise_min;
  real<lower = prior_sigma_noise_min> prior_sigma_noise_max;
  real<lower = 0, upper = 1> prior_fraction_signal_min;
  real<lower = prior_fraction_signal_min, upper = 1> prior_fraction_signal_max;
  
  // A boolean switch for whether we sample from the
  // posterior or the prior, to see how and how much the data are updating our
  // beliefs about the parameters and the kind of data they generate.
  // 0 for prior, 1 for posterior
  int<lower = 0, upper = 1> get_posterior_not_prior;
}

parameters {
  real<lower = prior_mu_signal_min,       upper = prior_mu_signal_max> mu_signal;
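  // The upper bound couples mu_noise to mu_signal (mu_noise < mu_signal). This
  // needs prior_mu_signal_min >= prior_mu_noise_min, otherwise the upper bound
  // could fall below the lower bound.
  real<lower = prior_mu_noise_min,        upper = fmin(prior_mu_noise_max, mu_signal)> mu_noise;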
  real<lower = prior_sigma_signal_min,    upper = prior_sigma_signal_max> sigma_signal;
  real<lower = prior_sigma_noise_min,     upper = prior_sigma_noise_max> sigma_noise;
  real<lower = prior_fraction_signal_min, upper = prior_fraction_signal_max> fraction_signal;
}

// For each observation, calculate a variable lp_signal which is
// log(Prob(data | params, this observation is signal) * Prob(this observation is signal))
// and a variable lp_noise which is
// log(Prob(data | params, this observation is noise)  * Prob(this observation is noise))
// For a given observation, exponentiating each of these terms and adding them gives
// Prob(data | params), by the law of total probability: signal and noise are
// mutually exclusive and exhaustive, so P(signal) + P(noise) = 1.
transformed parameters {
  array[num_observations] real lp_signal;
  array[num_observations] real lp_noise;
  for (observation in 1:num_observations) {
    // log1m(x) = log(1 - x), computed more accurately when x is close to 0
    lp_signal[observation] = lognormal_lpdf(y[observation] | mu_signal, sigma_signal) + log(fraction_signal);
    lp_noise[observation]  = lognormal_lpdf(y[observation] | mu_noise,  sigma_noise)  + log1m(fraction_signal);
  }
}

model {
  
  // Priors not implicitly defined by the ranges stated at parameter declaration 
  // should be defined here. 
  

  // Calculate the likelihood if we're sampling from the posterior.
  // The hypotheses of signal and noise are mutually exclusive, so we add their
  // likelihoods weighted by P(signal) and P(noise).
  // To increment the overall log probability density by log(a + b), where a is
  // the likelihood assuming it's signal times the probability of signal, i.e.
  // exp(lp_signal), and b is similar but for noise, we use
  // log_sum_exp(lp_signal, lp_noise), defined such that
  // log_sum_exp(log(a), log(b)) = log(a + b).
  if (get_posterior_not_prior) {
    for (observation in 1:num_observations) {
      target += log_sum_exp(lp_signal[observation], lp_noise[observation]);
    }
  }
  
}

// The posterior probability that a given observation is noise, conditional on
// specific parameter values (dropped from the notation here for clarity), is
// P(noise | data) = P(data | noise) P(noise) / P(data)
//                 = P(data | noise) P(noise) / [
//                    P(data | noise) P(noise) + P(data | signal) P(signal) ]
//                 = exp(lp_noise) / [exp(lp_noise) + exp(lp_signal)]
//                 = exp(lp_noise) / exp(log_sum_exp(lp_signal, lp_noise))
//                 = exp(lp_noise - log_sum_exp(lp_signal, lp_noise))
// We calculate this per observation below.
// Restoring the conditioning on parameter values, this quantity is
// P(noise | data, params). Stan returns draws of it across parameter space,
// sampling params in proportion to P(params | data), so the mean of these
// draws estimates P(noise | data).
generated quantities {
  array[num_observations] real<lower = 0, upper = 1> prob_is_noise_given_params;
  for (observation in 1:num_observations) {
    prob_is_noise_given_params[observation] = exp(
      lp_noise[observation] - log_sum_exp(lp_noise[observation], lp_signal[observation]));
  }
}
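As a numerical check of the derivation in the comments above, here is a minimal base-R sketch (not part of Rob’s code) that mirrors lp_signal, lp_noise and prob_is_noise_given_params. It assumes that R’s dlnorm() uses the same meanlog/sdlog parameterisation as Stan’s lognormal_lpdf; the parameter values and test points are made up. It also shows why the log scale matters: for an extreme y both weighted densities underflow to zero, so the direct ratio is 0/0 = NaN, while the log-sum-exp version still returns a probability.

```r
# Numerically stable log(exp(a) + exp(b)), vectorised over a and b
log_sum_exp2 <- function(a, b) {
  m <- pmax(a, b)
  m + log1p(exp(-abs(a - b)))
}

# P(noise | y_i, params) for each y_i, mirroring the Stan generated quantities
prob_is_noise <- function(y, mu_signal, sigma_signal, mu_noise, sigma_noise,
                          fraction_signal) {
  lp_signal <- dlnorm(y, mu_signal, sigma_signal, log = TRUE) + log(fraction_signal)
  lp_noise  <- dlnorm(y, mu_noise, sigma_noise, log = TRUE) + log1p(-fraction_signal)
  exp(lp_noise - log_sum_exp2(lp_noise, lp_signal))
}

# Made-up parameter values and observations
params <- list(mu_signal = 10, sigma_signal = 3, mu_noise = 3,
               sigma_noise = 0.8, fraction_signal = 0.8)
y_example <- c(5, 20, 1e4)
p_log_scale <- do.call(prob_is_noise, c(list(y = y_example), params))

# Direct Bayes-rule ratio on the natural scale, for comparison
direct_ratio <- function(y) {
  w_noise  <- (1 - params$fraction_signal) * dlnorm(y, params$mu_noise, params$sigma_noise)
  w_signal <- params$fraction_signal * dlnorm(y, params$mu_signal, params$sigma_signal)
  w_noise / (w_noise + w_signal)
}
p_direct <- direct_ratio(y_example)

# An extreme observation: both weights underflow, so the direct ratio is NaN
y_extreme <- 1e-300
p_extreme_direct <- direct_ratio(y_extreme)
p_extreme_log <- do.call(prob_is_noise, c(list(y = y_extreme), params))
```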

Example application to simulated data in R:

# PRELIMINARIES ----------------------------------------------------------------
#
# Authors: Rob Hinch and Chris Wymant
# Acknowledgment: written while funded by a Li Ka Shing Foundation grant 
# awarded to Christophe Fraser.
#
# Abbreviations:
# df = dataframe
# num = number
# mcmc = Markov Chain Monte Carlo
# param = parameter
# mu = the meanlog parameter of a lognormal
# sigma = the sdlog parameter of a lognormal
#
# Script purpose: we simulate a set of values y from either a 'signal'
# distribution or a 'noise' distribution, both lognormals but with different
# parameters, and then get Stan to infer the parameters of those distributions
# and which values were signal and which were noise.

library(tidyverse)
library(rstan)
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

# INPUT ------------------------------------------------------------------------
 
# The path of the Stan code associated with this R code 
file_input_stan_code <- "estimate_binary_vector.stan"

# Population-level params for the simulation
num_observations <- 1000
fraction_signal <- 0.8
mu_signal <- 10
sigma_signal <- 3
mu_noise <- 3
sigma_noise <- 0.8

# SIMULATE DATA ----------------------------------------------------------------

df <- tibble(
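  # purrr::rbernoulli() is deprecated; comparing a uniform draw with
  # fraction_signal gives the same logical Bernoulli(fraction_signal) vector
  is_signal = runif(num_observations) < fraction_signal,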
  vl_observed = if_else(is_signal,
                        rlnorm(num_observations, mu_signal, sigma_signal),
                        rlnorm(num_observations, mu_noise,  sigma_noise)))

# RUN STAN ---------------------------------------------------------------------

stan_input_posterior <- list(
  num_observations = nrow(df),
  y = df$vl_observed,
  prior_mu_signal_min = 7,
  prior_mu_signal_max = 15,
  prior_sigma_signal_min = 0.5,
  prior_sigma_signal_max = 5,
  prior_mu_noise_min = 1,
  prior_mu_noise_max = 6,
  prior_sigma_noise_min = 0.5,
  prior_sigma_noise_max = 5,
  prior_fraction_signal_min = 0,
  prior_fraction_signal_max = 1,
  get_posterior_not_prior = 1L
)
stan_input_prior <- stan_input_posterior
stan_input_prior$get_posterior_not_prior <- 0L

# Compile the Stan code
model_compiled <- stan_model(file_input_stan_code)

# Run the Stan code
num_mcmc_iterations <- 500
num_mcmc_chains <- 4
fit <- sampling(model_compiled,
                data = stan_input_posterior,
                iter = num_mcmc_iterations,
                chains = num_mcmc_chains)
fit_prior <- sampling(model_compiled,
                      data = stan_input_prior,
                      iter = num_mcmc_iterations,
                      chains = num_mcmc_chains)

# ANALYSE STAN OUTPUT ----------------------------------------------------------

# Get stan output into a long df labelled by posterior/prior
df_fit <- bind_rows(fit %>%
                      as.data.frame() %>% 
                      mutate(density_type = "posterior",
                             sample = row_number()),
                    fit_prior %>% 
                      as.data.frame() %>%
                      mutate(density_type = "prior",
                             sample = row_number())) %>%
  as_tibble() %>%
  pivot_longer(-c("sample", "density_type"), names_to = "param")

# Plot the prior and posterior for population-level (not individual-level) 
# params and the simulation truth
df_params_true <- tibble(
  param = c("fraction_signal", "mu_signal", "sigma_signal", "mu_noise",
            "sigma_noise"),
  value = c(fraction_signal, mu_signal, sigma_signal, 
            mu_noise, sigma_noise))
ggplot(df_fit %>%
  filter(!startsWith(param, "lp"),
         !startsWith(param, "prob_is_noise_given_params"))) +
  geom_histogram(aes(value, y = after_stat(density), fill = density_type),
                 alpha = 0.6,
                 position = "identity",
                 bins = 60) +
  geom_vline(data = df_params_true, aes(xintercept = value), color = "black") +
  facet_wrap(~param, scales = "free", nrow = 3) +
  scale_fill_brewer(palette = "Set1") +
  theme_classic() +
  coord_cartesian(expand = FALSE) +
  labs(fill = "",
       x = "param value",
       y = "probability density")

# Calculate the prior and posterior means for each observation's probability of 
# being noise. Link observations to their simulation truth. Binarise the
# means and compare against whether they really were noise.
df_fit %>% 
  filter(startsWith(param, "prob_is_noise_given_params")) %>%
  group_by(density_type, param) %>%
  summarise(mean_prob_is_noise = mean(value), .groups = "drop") %>%
  mutate(observation =
           str_match(param, "^prob_is_noise_given_params\\[([0-9]+)\\]$")[, 2] %>%
           as.integer()) %>%
  left_join(df %>% mutate(observation = row_number()),
            by = "observation") %>%
  mutate(correctly_classified = (mean_prob_is_noise < 0.5 & is_signal) |
           (mean_prob_is_noise > 0.5 & ! is_signal)) %>%
  group_by(density_type) %>%
  summarise(fraction_correctly_classified = sum(correctly_classified) / n())
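For context when reading that fraction, one rough benchmark (our suggestion, not part of the original script) is the accuracy obtained by plugging the simulation’s true parameter values into the same two-state formula, i.e. classifying as if the population-level parameters were known. A self-contained base-R sketch, re-simulating data with the same settings but its own seed, and with runif() in place of rbernoulli():

```r
# Benchmark sketch: classify each observation using the TRUE parameter values
set.seed(42)
num_observations <- 1000
fraction_signal <- 0.8
mu_signal <- 10
sigma_signal <- 3
mu_noise <- 3
sigma_noise <- 0.8

# Simulate as in the script above, with base R only
is_signal <- runif(num_observations) < fraction_signal
y <- ifelse(is_signal,
            rlnorm(num_observations, mu_signal, sigma_signal),
            rlnorm(num_observations, mu_noise, sigma_noise))

# Log of (density x mixing weight) for each process, at the true parameters
lp_signal <- dlnorm(y, mu_signal, sigma_signal, log = TRUE) + log(fraction_signal)
lp_noise <- dlnorm(y, mu_noise, sigma_noise, log = TRUE) + log1p(-fraction_signal)

# P(noise | y_i, true params); if exp() overflows to Inf the result is simply 0
prob_noise_at_truth <- 1 / (1 + exp(lp_signal - lp_noise))

# Same decision rule as the script: call it signal when P(noise) < 0.5
fraction_correct_at_truth <- mean((prob_noise_at_truth < 0.5) == is_signal)
fraction_correct_at_truth
```

Note that it reassigns the script’s simulation variables (to the same values), so run it in a fresh session if that matters.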

Hi @ChrisWymant, nice write-up! This technique is correct and known; it amounts to normalizing the probabilities associated with the different states into a proper probability mass function.
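The same normalisation works for any number of states. A minimal base-R sketch, assuming a matrix of unnormalised log weights (log density plus log mixing weight) with one row per observation and one column per state; with two columns (signal, noise) it reduces to the calculation in the generated quantities block. The weights in the example are made up.

```r
# Row-wise log(sum(exp(row))), computed stably by subtracting each row's max.
# Assumes every row has at least one finite entry.
row_log_sum_exp <- function(log_w) {
  m <- apply(log_w, 1, max)
  # m has one entry per row; R recycles it column by column, so row i has m[i]
  # subtracted from every element
  m + log(rowSums(exp(log_w - m)))
}

# Turn each row of log weights into a probability mass function
normalise_log_weights <- function(log_w) {
  exp(log_w - row_log_sum_exp(log_w))
}

# Made-up log weights for two observations and three states
log_w <- rbind(c(-1000, -1001, -1002),  # exp() of these underflows to 0
               log(c(0.2, 0.3, 0.5)))
probs <- normalise_log_weights(log_w)
rowSums(probs)
```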

For what it’s worth, I think that the term “latent” is not necessarily used in opposition to “of interest” (I don’t think it would be a contradiction in terms to say “in this model, the latent parameters are of interest”).


I just had the opportunity to look at the multinomial_logit code in Stan Math, which does this exact thing.

This is also described in the Wikipedia article “Multinomial logistic regression”; the only difference is that in Stan we stay on the log scale.
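To make the connection concrete: with two states, the softmax is the logistic (inverse-logit) function of the log-odds, so P(noise | data, params) = inv_logit(lp_noise - lp_signal), and log P(noise | data, params) = log_inv_logit(lp_noise - lp_signal) if one wants to stay on the log scale (both are built-in Stan functions). A base-R sketch with made-up log weights, using plogis() as the inverse logit:

```r
# Made-up log weights, i.e. log(density x mixing weight), for three observations
lp_signal <- c(-3.2, -50, -800)
lp_noise  <- c(-4.1, -10, -1600)

# Stable log(exp(a) + exp(b)), vectorised
log_sum_exp2 <- function(a, b) pmax(a, b) + log1p(exp(-abs(a - b)))

# Softmax route, as in the Stan generated quantities block
p_softmax <- exp(lp_noise - log_sum_exp2(lp_noise, lp_signal))

# Logistic route: inverse logit of the log-odds of noise versus signal
p_logistic <- plogis(lp_noise - lp_signal)

# Log scale: log P(noise) stays finite even where P(noise) underflows to 0
log_p_noise <- plogis(lp_noise - lp_signal, log.p = TRUE)
```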


Note also that this technique is discussed (but buried and not particularly discoverable) in the Stan User’s Guide chapter on Finite Mixture Models: Section 5.3, under the heading “Recovering posterior mixture proportions”.
