Full Bayesian imputation forking path (aka full Bayesian chained equations)

This example is heavily inspired by the rethinking book. Let us assume two predictor variables and one Gaussian outcome. The two predictors are correlated through the causal effect of an unobserved factor. Moreover, both variables have missing values. A DAG for this model could be:

library(ggdag)
dagify(X_obs ~ X + X_R,
       Y_obs ~ Y + Y_R,
       X ~ U,
       Y ~ U,
       X_R ~ Y,
       Y_R ~ X,
       O ~ X + Y)

To simulate the data in R, for example:

library(rethinking)
library(tidyverse)
set.seed(9)
N <- 100L
mu <- 5


U <- sample(0:2, size = N, replace = TRUE, prob = c(.3, .4, .3)) # unobserved three-level factor
b_x <- (-3)
b_y <- -5
U_x <- 3
U_y <- 2
k_U <- c(0, U_x, U_y)
k_x <- -.5
k_y <- -1

r_x <- 0.2
r_y <- 0.2

## Using unobserved U to sample x and y
cat_x <- U %>% map_dbl(function(u){
  if(u == 1){
    U_x
  } else {
    0
  }
})

cat_y <- U %>% map_dbl(function(u){
  if(u == 2){
    U_y
  } else {
    0
  }
})
x <- rbern( N , inv_logit(k_x + cat_x ) )
y <- rbern( N,  inv_logit(k_y + cat_y ) )

R_x <- rbern( N , r_x )
R_y <- rbern( N, r_y )
x_obs <- x
x_obs[R_x==1] <- (-9L)
y_obs <- y
y_obs[R_y==1] <- (-9L)


sigma <- 0.5
simdata <- tibble(id = 1:N,
                  eta = mu + b_x * x + b_y * y,
                  outcome = rnorm(N, eta, sigma))
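
The observed pieces can then be packed into a list whose names match the data block of the Stan model below:

## data for Stan: masked predictors, missingness indicators, and the outcome;
## names match the data block of the model below
stan_data <- list(
  N = N,
  R_x = R_x,
  R_y = R_y,
  x = x_obs,
  y = y_obs,
  outcome = simdata$outcome
)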

Now the model in Stan:

data {
  int N;
  int R_x[N];  // missingness indicator: R_x[n] = 1 if x[n] is missing, 0 otherwise
  int R_y[N];
  int x[N];    // binary predictors, coded -9 where missing
  int y[N];
  
  vector[N] outcome; // Gaussian outcome variable
}

transformed data {
  int x_id[N];
  int y_id[N];
  for(n in 1:N){
    x_id[n] = x[n] + 1;
    y_id[n] = y[n] + 1;
  }
}

parameters {
  real mu; // intercept 
  real<lower=0> sigma; 
  real b_x; // estimated effect of x
  real b_y; // estimated effect of y
  real theta_raw_x; // intercept for imputation model
  real theta_raw_y; // 
  real beta_y_x; // beta y on x in imputation model
  real beta_x_y; // beta x on y in imputation model
}

transformed parameters {
  real<lower = 0, upper = 1> theta_x_y0 = inv_logit(theta_raw_x);
  real<lower = 0, upper = 1> theta_x_y1 = inv_logit(theta_raw_x + beta_y_x);
  real<lower = 0, upper = 1> theta_y_x0 = inv_logit(theta_raw_y);
  real<lower = 0, upper = 1> theta_y_x1 = inv_logit(theta_raw_y + beta_x_y);
  vector<lower = 0, upper = 1>[2] theta_x = [theta_x_y0, theta_x_y1]';
  vector<lower = 0, upper = 1>[2] theta_y = [theta_y_x0, theta_y_x1]';
}

model {
  // priors  
  // Bernoulli missingness model, default priors
  target += cauchy_lpdf (theta_raw_x | 0.0, 10.0); 
  target += cauchy_lpdf (beta_y_x | 0.0, 2.5); 
  
  target += cauchy_lpdf (theta_raw_y | 0.0, 10.0); 
  target += cauchy_lpdf (beta_x_y | 0.0, 2.5); 
  
  // priors for outcome model
  target += normal_lpdf (mu | 0.0, 10.0);
  target += normal_lpdf (b_x | 0.0, 2.5);
  target += normal_lpdf (b_y | 0.0, 2.5);
  target += cauchy_lpdf (sigma | 0.0, 1.0);
  
  // likelihood for imputation (missing sub-model)
  for ( i in 1:N ) {
    if ( R_x[i] == 1  && R_y[i] == 1) { 
      
       // do nothing
    
    } else if ( R_x[i] == 1 && R_y[i] == 0) {
      
      // missing model
      target += log_sum_exp(
         log ( theta_x[y_id[i]] ) + 
         bernoulli_logit_lpmf (y[i] | theta_raw_y),
         log (1 - theta_x[y_id[i]]) +
         bernoulli_logit_lpmf (y[i] | theta_raw_y + beta_x_y) );
       
    } else if ( R_x[i] == 0 && R_y[i] == 1) {
        
      // missing model
      target += log_sum_exp(
         log (theta_y[x_id[i]]) + 
         bernoulli_logit_lpmf (x[i] | theta_raw_x),
         log (1 - theta_y[x_id[i]]) +
         bernoulli_logit_lpmf (x[i] | theta_raw_x + beta_y_x) );
       
    } else if ( R_x[i] == 0 && R_y[i] == 0 ) { // both x and y observed
    
      target += bernoulli_logit_lpmf (x[i] | theta_raw_x + beta_y_x * y[i]);
      target += bernoulli_logit_lpmf (y[i] | theta_raw_y + beta_x_y * x[i]);
    } // close if else missing
  } // close for loop individuals
    
  // likelihood contribution for outcome model (split by missing and non-missing)
  for ( i in 1:N ) { // for each individual
  
    if ( R_x[i] == 1  && R_y[i] == 1) { // if x and y are missing
      
      target += log_sum_exp(
        [log (theta_x[2]) + log (theta_y[2]) + 
        normal_lpdf  (outcome[i] | mu + b_x * 1 + b_y * 1, sigma),
        log (theta_x[1]) + log( 1 - theta_y[2]) +
        normal_lpdf  (outcome[i] | mu + b_x * 1 + b_y * 0, sigma),
        log (1 - theta_x[2]) + log (theta_y[1]) +
        normal_lpdf  (outcome[i] | mu + b_x * 0 + b_y * 1, sigma),
        log (1 - theta_x[1]) + log (1 - theta_y[1]) +
        normal_lpdf  (outcome[i] | mu + b_x * 0 + b_y * 0, sigma)]' );
        
    } else if ( R_x[i] == 1 && R_y[i] == 0) {
      
      target += log_sum_exp(
        log(theta_x[y_id[i]]) + 
        normal_lpdf  (outcome[i] | mu + b_x * 1 + b_y * y[i], sigma),
        log(1 - theta_x[y_id[i]]) +
        normal_lpdf  (outcome[i] | mu + b_x * 0 + b_y * y[i], sigma) );
      
    } else if ( R_x[i] == 0 && R_y[i] == 1) {
      
      target += log_sum_exp(
        log(theta_y[x_id[i]]) + 
        normal_lpdf  (outcome[i] | mu + b_x * x[i] + b_y * 1, sigma),
        log(1 - theta_y[x_id[i]]) +
        normal_lpdf  (outcome[i] | mu + b_x * x[i] + b_y * 0, sigma)) ;
      
    } else if ( R_x[i] == 0 && R_y[i] == 0 ) { // both x and y observed
      
      target += normal_lpdf  (outcome[i] | mu + b_x * x[i] + b_y * y[i], sigma);
      
    } else {
      
      reject("Value of R_  not recognised. Is there something wrong with the R_ (missing variables)? (events)")
      
    } // close if else missing
  } // close for loop individuals
}
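
To fit it, something along these lines works (assuming the model above is saved as imputation_forking.stan, a file name I just made up, and stan_data is the list built after the simulation code; rstan shown here, cmdstanr works just as well):

library(rstan)

fit <- stan(
  file = "imputation_forking.stan",  # the Stan program above
  data = stan_data,                  # list built after the simulation code
  chains = 4,
  iter = 2000,
  seed = 9
)
print(fit, pars = c("mu", "b_x", "b_y", "sigma"))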

All this works well. My question then is this: there is clearly a pattern emerging here; it is the classical forking-path pattern re-emerging. The goal is to use this technique with many more covariates, maybe 10, and then the number of missingness combinations explodes (2^10 = 1024), which will be hard to code and bug-prone. Is there a better way to do this currently in Stan, in other words a more compact way to create this forking path? And if there is no better way than hard-coding the forking paths, can you imagine a better way appearing in future Stan releases?

Thank you,

Hi, sorry for taking so long to respond.
I don’t completely understand your model, but what I understand is that you need to encode parameters and likelihoods over a tree of options. There is no direct way to do this currently, and I don’t think there will be. You have, IMHO, two options to do this reasonably. Both require you to write a program that enumerates all the possible options.

  1. Code generation. Write R/Python/whatever code that generates the branching Stan code for you. This is the approach taken by brms and it can work pretty well.
  2. You can also do the enumeration in Stan with recursive functions, along the lines of (not tested, just a sketch for a single binary outcome, might be wrong, check my reasoning):
real my_tree_lpmf(int level, int outcome, real mu_so_far, vector beta, vector theta, int[] x, int[] is_missing) {
   if (level > num_elements(theta)) {
       // no more variables to process: evaluate the outcome likelihood
       return bernoulli_logit_lpmf(outcome | mu_so_far);
   } else {
      if (is_missing[level]) {
        // marginalize over the two possible values of the missing predictor
        return log_mix(theta[level],
           // the predictor is missing and is 1
           my_tree_lpmf(level + 1 | outcome, mu_so_far + beta[level], beta, theta, x, is_missing),
           // the predictor is missing and is 0
           my_tree_lpmf(level + 1 | outcome, mu_so_far, beta, theta, x, is_missing)
         );
      } else {
           // the predictor is observed: add its contribution and move on
           return my_tree_lpmf(level + 1 | outcome, mu_so_far + beta[level] * x[level], beta, theta, x, is_missing);
      }
   }
}

Which you would then call as:

target += my_tree_lpmf(1 | outcome[i], 0 /*mu_so_far*/, beta, theta, x[i], is_missing[i]);

Here, level encodes where in the tree of options we are, and notice how mu_so_far accumulates the linear predictor as we go down the tree, until all levels are traversed, at which point you can compute the likelihood. Once this leaf of the tree is reached, we “go up” and this value is either returned directly (when the variable is known) or aggregated with the likelihood of the other option (when the variable is not known).

It can be quite hard to reason about recursive functions if you are not used to them, but there are reasonable online tutorials on recursion should you feel lost. Also note that this can computationally explode and take forever if you have a lot of missing values, as the effort is exponential in the number of missing values per observation.
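
If it helps to play with the idea outside Stan, here is the same recursion mirrored in R for a Gaussian outcome (as in your original model rather than the binary outcome of the sketch). It is only an illustration with made-up theta values; you can check it against a brute-force sum over all combinations:

## the recursion above, mirrored in R for a Gaussian outcome
log_sum_exp <- function(v) { m <- max(v); m + log(sum(exp(v - m))) }

tree_logdens <- function(level, outcome, mu_so_far, beta, theta, x, is_missing, sigma) {
  if (level > length(theta)) {
    # all predictors handled: evaluate the outcome density at the accumulated mean
    return(dnorm(outcome, mean = mu_so_far, sd = sigma, log = TRUE))
  }
  if (is_missing[level] == 1) {
    # marginalise over the two possible values of the missing predictor
    lp1 <- log(theta[level]) +
      tree_logdens(level + 1, outcome, mu_so_far + beta[level], beta, theta, x, is_missing, sigma)
    lp0 <- log(1 - theta[level]) +
      tree_logdens(level + 1, outcome, mu_so_far, beta, theta, x, is_missing, sigma)
    return(log_sum_exp(c(lp1, lp0)))
  }
  # observed predictor: add its contribution to the mean and move on
  tree_logdens(level + 1, outcome, mu_so_far + beta[level] * x[level],
               beta, theta, x, is_missing, sigma)
}

## one observation with x observed (= 1) and y missing; mu, beta, sigma taken
## from your simulation, theta values made up
tree_logdens(1, outcome = 1.2, mu_so_far = 5, beta = c(-3, -5),
             theta = c(0.4, 0.6), x = c(1, NA), is_missing = c(0, 1), sigma = 0.5)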

Finally, I believe that in your code the “do nothing” part when all predictors are missing is wrong: you should actually marginalize there over all four possible combinations. The rows with no predictors known are still informative, because, say, if all of those are 0, then the most-likely-to-be-missing predictors can’t have high beta values.
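
To make that concrete, the contribution of such a row is a sum over the four (x, y) combinations, each weighted by its probability. A small R illustration with made-up parameter values, keeping for the moment a single probability per predictor (more on that below):

## log-likelihood contribution of one row with both x and y missing:
## sum over the four combinations, each weighted by its probability
log_sum_exp <- function(v) { m <- max(v); m + log(sum(exp(v - m))) }

marginal_both_missing <- function(outcome, mu, b_x, b_y, sigma, p_x, p_y) {
  combos <- expand.grid(x = 0:1, y = 0:1)
  lps <- apply(combos, 1, function(z)
    dbinom(z["x"], 1, p_x, log = TRUE) +
    dbinom(z["y"], 1, p_y, log = TRUE) +
    dnorm(outcome, mu + b_x * z["x"] + b_y * z["y"], sigma, log = TRUE))
  log_sum_exp(lps)
}

## made-up values, just to show the mechanics
marginal_both_missing(outcome = 1.0, mu = 5, b_x = -3, b_y = -5,
                      sigma = 0.5, p_x = 0.4, p_y = 0.6)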

Some minor things:

transformed parameters {
  real<lower = 0, upper = 1> theta_x_y0 = inv_logit(theta_raw_x);
}

this should be literally equivalent to just having

parameters {
  real<lower = 0, upper = 1> theta_x_y0;
}

Stan performs those constraining transforms automatically.

Also, you have very wide priors, which might fail to regularize your model sufficiently. And when using constrained parameters for probabilities, you can put beta priors on them, which should work fine.

Hope this is making at least a little sense to you - please double check me :-)
Best of luck with your model!

One thing I just realized that might also be problematic in both your model and my proposed implementation: when handling missing values, you usually need one parameter per missing value. In the implementation you’ve written, theta_x[1] is shared by all rows where y_id[i] == 1. So you assume that, for all rows with y_id[i] == 1, either all missing values of x are 0 or all of them are 1, nothing in between!

Hi @martinmodrak, no problem, thanks for the answer!

Yes, that was exactly my challenge in the beginning. The solution that I have implemented for now is simple, because I just need to recurse over 5 levels, so I simply go for nested loops that follow the lexicographic order; with 3 variables, something like:

for (x1_imp in 0:1) {
  for (x2_imp in 0:1) {
    for (x3_imp in 0:1) {
      // compute the probabilities ....
    }
  }
} // close for loops
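
As an aside, the same set of combinations can be generated compactly in R, e.g. to build a lookup table that is passed to Stan as data; a tiny sketch with invented column names:

## one row per combination of imputed values, instead of hand-written nested loops
k <- 3
combos <- expand.grid(rep(list(0:1), k))
names(combos) <- paste0("x", seq_len(k), "_imp")
combos  # 2^k = 8 rows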

The computational explosion is exactly what happened, and here I found the map_rect functionality super useful: I could reduce the computational burden 100-fold using 12 threads. It took around 2-3 hours to code and test the packing and unpacking, but it was totally worth it.

Huh! I did not think about that, TBH. I tested the model with simulated data under this scheme and could recover the parameters, so I am not sure about this. I mean, it does nothing for the imputation bit, but it does contribute to the log-probability later on:

if ( R_x[i] == 1  && R_y[i] == 1) { // if x and y are missing
      
      target += log_sum_exp(...);

I am not sure I understand your point: are you saying that, when all variables are missing, we should iterate over all combinations of possible values for the imputable variables?

PS: sorry if there are typos in the example above; it is over-simplified and I did not go back over it to check, so it might contain mistakes. I will update it when I find a moment.

That is my current understanding, but I am by no means an expert.

Testing if you can recover parameters is a first step, but unless you can pass a full SBC (Simulation-based calibration, search for it to find references), it is easy to fool oneself into believing your model works when in fact there is some small pathology.
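
In case it helps, the core of the SBC loop is small. Here is a toy sketch in R where a conjugate normal-mean model stands in for the Stan fit, just to show the mechanics; for your model you would replace the analytic posterior draws with draws from the fitted Stan model:

## toy SBC: prior N(0, 1) on the mean, known sd = 1, so the posterior is analytic
set.seed(1)
n_sims <- 200; n_obs <- 20; n_draws <- 100
ranks <- replicate(n_sims, {
  mu_true <- rnorm(1, 0, 1)             # 1. draw the parameter from its prior
  y <- rnorm(n_obs, mu_true, 1)         # 2. simulate data given that parameter
  post_sd <- sqrt(1 / (1 + n_obs))      # 3. "fit": conjugate posterior here
  post_mean <- post_sd^2 * sum(y)
  draws <- rnorm(n_draws, post_mean, post_sd)
  sum(draws < mu_true)                  # 4. rank of the true value among the draws
})
hist(ranks, breaks = 20)  # should look roughly uniform if the fit is calibrated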

Hope that helps.

Yes, I saw this before but always forget to do it because of (and now comes the excuse) the classic excuse of applied scientists who are a bit impatient to apply the model to the data at hand. But this time I will check this too, thank you!

I actually realized the thing is probably even trickier. For each observation with n missing predictors, you IMHO need to iterate over all 2^n possible combinations of those predictors. Further, you need a simplex of size 2^n for each observation to keep track of the relative probabilities of each combination. Having just n parameters (as I mentioned earlier) would not work, because the combinations are (IMHO) not independent. This is easy to see if we assume this is just normal regression with a continuous normal outcome - for example, if you have three binary predictors and observations such that:

id outcome   x1 x2 x3
A  0.1       1  -  -
B  1.3       0  -  -
... many other observations with all predictors known.

Where - means a missing predictor.
Imagine the rest of your data informs \beta_1, \beta_2, \beta_3 quite well, so you know that roughly \beta_1 \simeq 0.5, \beta_2 \simeq +10, \beta_3 \simeq -10. Then for both observations A and B, the missing predictor values can plausibly be (0,0) or (1,1), but both (1,0) and (0,1) are highly unlikely. Further, for A, (0,1) and (1,0) are almost equally (un)likely, but for B, (0,1) is slightly more likely than (1,0). So you really need to model this as two separate 4-simplices (i.e. 6 actual unconstrained parameters).
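
To put rough numbers on this (a quick R check of the example above; the intercept of 0 and residual sd of 1 are assumptions I added just for illustration, and the weights use the outcome likelihood only):

## relative plausibility of the four (x2, x3) combinations for observations A and B
beta <- c(0.5, 10, -10)
combos <- expand.grid(x2 = 0:1, x3 = 0:1)
post_combo <- function(outcome, x1) {
  lik <- dnorm(outcome, beta[1] * x1 + beta[2] * combos$x2 + beta[3] * combos$x3, 1)
  round(lik / sum(lik), 3)
}
cbind(combos, A = post_combo(0.1, 1), B = post_combo(1.3, 0))
## Nearly all of the mass falls on (0,0) and (1,1): x2 and x3 are strongly
## dependent given the outcome, which independent per-predictor probabilities
## cannot represent, hence the per-observation simplex.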

I am not 100% sure how this logic transfers to binary outcomes as those contain very little information, but my hunch is that it mostly stays the same.

EDIT: It stays the same. When you convert your data to binomial outcomes (by grouping rows that have the same predictor values together), you can make almost exactly the same example as above with the normal.

Now this would be pretty tricky to code correctly in Stan, because there is no direct support for an array of simplices of varying sizes (it IS possible to emulate, just very tricky).

Hope I am making sense.

It all boils down to the risk you are willing to take that you are wrong :-). But the model you are trying to make is IMHO quite complex, so I believe some extra scrutiny is reasonable.