Truncation with vectorized sampling statements

It’s my impression that truncation with T[L,U] doesn’t work with a vectorized sampling statement. That’s a bit unfortunate, since vectorized distribution functions are much faster. But would the following work? It relies on an approach similar to the one in the Integrating out Censored Values section of the user’s guide.

data {
  int<lower=0> N;
  vector<lower=0>[N] y;
  int<lower=0,upper=1> lccdf;
}
parameters {
  real<lower=0> mu;
  real<lower=0> sigma;
}
model {
  mu ~ normal(0, 5);
  sigma ~ gamma(2, 1);
  
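  if (lccdf) {
    // Vectorized: untruncated density for all of y, then the
    // truncation correction (one lccdf term per observation).
    y ~ normal(mu, sigma);
    target += -normal_lccdf(0 | mu, sigma) * N;
  } else {
    // Loop: built-in truncation, one observation at a time.
    for (i in 1:N) y[i] ~ normal(mu, sigma) T[0,];
  }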
}

I think the flag lccdf switches between two ways of writing the same truncated normal likelihood. One way relies on a for loop, and it samples relatively slowly. I think the other way implements the same truncation but still gets the speed benefit of a vectorized distribution function. The two produce similar posterior draws, though I was surprised that the user’s guide and reference manual only show the loop approach.
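
For what it’s worth, here is why I’d expect the two to match. This is only a sketch, assuming mu, sigma, and the lower bound 0 are shared by all N observations; Φ is my notation for the standard normal CDF. Each truncated observation has density

$$
p(y_i \mid \mu, \sigma, y_i \ge 0) = \frac{\mathrm{Normal}(y_i \mid \mu, \sigma)}{1 - \Phi\!\left(\frac{0 - \mu}{\sigma}\right)}
$$

so the log-likelihood over all observations is

$$
\sum_{i=1}^{N} \log p(y_i \mid \mu, \sigma, y_i \ge 0) = \sum_{i=1}^{N} \log \mathrm{Normal}(y_i \mid \mu, \sigma) - N \log\!\left(1 - \Phi\!\left(\frac{-\mu}{\sigma}\right)\right)
$$

The first sum is what the vectorized y ~ normal(mu, sigma) contributes (up to constants), and the log term is normal_lccdf(0 | mu, sigma), which is why it gets multiplied by N. If mu, sigma, or the bound varied by observation, the single N-times term would have to become a sum of per-observation terms.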

Here’s an R script to run the model, using the cmdstanr interface.

library(cmdstanr)

m <- cmdstan_model("truncation.stan")

rtnorm <- function(n, mean = 0, sd = 1, lower=-Inf, upper=Inf){
  l <- pnorm(lower, mean = mean, sd = sd)
  u <- pnorm(upper, mean = mean, sd = sd)
  uni <- runif(n, l, u)
  qnorm(uni, mean=mean, sd=sd)
}

N <- 1e4
mu <- 0.5
sigma <- 5
y <- rtnorm(N, mean=mu, sd=sigma, lower=0)
d <- list(N = N, y = y, lccdf=1)

fit <- m$sample(
  data = d,
  seed = 123,
  chains = 3,
  parallel_chains = 3)

bayesplot::mcmc_pairs(fit$draws())


d2 <- list(N = N, y = y, lccdf=0)

fit2 <- m$sample(
  data = d2,
  seed = 123,
  chains = 3,
  parallel_chains = 3)

bayesplot::mcmc_pairs(fit2$draws())

That looks right to me. Neat trick. I wonder why the truncation doesn’t work with vectors. Maybe this is something we should add.

Does it give the same results either way? I’m not sure how to check this other than by evaluating the lpdfs and comparing the numbers.

I kinda like them both being target+= statements for this:

target += normal_lpdf(y | mu, sigma) - normal_lccdf(0 | mu, sigma) * N;

But I guess it won’t matter that the ~ statement drops some constants, since constant terms don’t affect sampling.

I wasn’t sure how to check, either. But, yes, it seems like looking at the output of the lpdfs should work. I think the following shows that the two methods are giving the same results.

data {
  int<lower=0> N;
  vector<lower=0>[N] y;
}
model {
  real lp1;
  real lp2;
  real lp1_0;
  real lp2_0;
  
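  // Method 1: vectorized statement plus the truncation correction.
  lp1_0 = target();
  y ~ normal(0, 1);
  target += -normal_lccdf(0 | 0, 1) * N;
  lp2_0 = target();
  lp1 = lp2_0 - lp1_0;
  
  // Method 2: loop with the built-in T[0,] truncation.
  for (i in 1:N) y[i] ~ normal(0, 1) T[0,];
  lp2 = target() - lp2_0;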
  
  print("lp2: ", lp2);
  print("lp1: ", lp1);
  print("difference: ", abs(lp2 - lp1));
}
generated quantities{
  real check = 0;
}

Run with

library(cmdstanr)

rtnorm <- function(n, mean = 0, sd = 1, lower=-Inf, upper=Inf){
  l <- pnorm(lower, mean = mean, sd = sd)
  u <- pnorm(upper, mean = mean, sd = sd)
  uni <- runif(n, l, u)
  qnorm(uni, mean=mean, sd=sd)
}

N <- 10
mu <- 0.5
sigma <- 5
y <- rtnorm(N, mean=mu, sd=sigma, lower=0)
d <- list(N = N, y = y)

m <- cmdstan_model("truncation_print.stan")

fit <- m$sample(
  data = d,
  seed = 123,
  chains = 1,
  iter_warmup = 2,
  iter_sampling = 0,
  fixed_param = TRUE)

fit$output()

gives

> fit$output()
[[1]]
 [1] ""                                                                                                
 [2] "method = sample (Default)"                                                                       
 [3] "  sample"                                                                                        
 [4] "    num_samples = 0"                                                                             
 [5] "    num_warmup = 2"                                                                              
 [6] "    save_warmup = 0 (Default)"                                                                   
 [7] "    thin = 1 (Default)"                                                                          
 [8] "    adapt"                                                                                       
 [9] "      engaged = 1 (Default)"                                                                     
[10] "      gamma = 0.050000000000000003 (Default)"                                                    
[11] "      delta = 0.80000000000000004 (Default)"                                                     
[12] "      kappa = 0.75 (Default)"                                                                    
[13] "      t0 = 10 (Default)"                                                                         
[14] "      init_buffer = 75 (Default)"                                                                
[15] "      term_buffer = 50 (Default)"                                                                
[16] "      window = 25 (Default)"                                                                     
[17] "    algorithm = fixed_param"                                                                     
[18] "id = 1"                                                                                          
[19] "data"                                                                                            
[20] "  file = C:/Users/psadi/AppData/Local/Temp/RtmpK8EpSN/standata-88c47e22dd8.json"                 
[21] "init = 2 (Default)"                                                                              
[22] "random"                                                                                          
[23] "  seed = 123"                                                                                    
[24] "output"                                                                                          
[25] "  file = C:/Users/psadi/AppData/Local/Temp/RtmpK8EpSN/truncation_print-202101081718-1-534bd1.csv"
[26] "  diagnostic_file =  (Default)"                                                                  
[27] "  refresh = 100 (Default)"                                                                       
[28] "  sig_figs = -1 (Default)"                                                                       
[29] ""                                                                                                
[30] "lp2: -145.471"                                                                                   
[31] "lp1: -145.471"                                                                                   
[32] "difference: 0"                                                                                   
[33] ""                                                                                                
[34] "lp2: 6.93147"                                                                                    
[35] "lp1: 6.93147"                                                                                    
[36] "difference: 0"                                                                                   
[37] ""                                                                                                
[38] ""                                                                                                
[39] " Elapsed Time: 0 seconds (Warm-up)"                                                              
[40] "               0 seconds (Sampling)"                                                             
[41] "               0 seconds (Total)"                                                                
[42] ""    

Both methods add the same amount to the target (the printed difference is 0), so I think they have the same effect.
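
As a cross-check outside Stan, the same comparison can be sketched in plain R. This is only an illustration under my own assumptions: dnorm and pnorm stand in for normal_lpdf and normal_lccdf, the names log_ccdf, lp_loop, and lp_vec are made up for the example, and all constants are kept, so the values are not expected to match the target printed by Stan.

```r
# Log-likelihood of a normal truncated below at `lower`, computed two ways.
set.seed(1)
mu <- 0.5
sigma <- 5
lower <- 0

# Inverse-CDF draws from the truncated normal (same idea as rtnorm above).
u <- runif(10, pnorm(lower, mean = mu, sd = sigma), 1)
y <- qnorm(u, mean = mu, sd = sigma)

# log P(Y > lower): the R analogue of normal_lccdf(lower | mu, sigma).
log_ccdf <- pnorm(lower, mean = mu, sd = sigma,
                  lower.tail = FALSE, log.p = TRUE)

# Loop form: one truncated log density per observation.
lp_loop <- 0
for (i in seq_along(y)) {
  lp_loop <- lp_loop + dnorm(y[i], mean = mu, sd = sigma, log = TRUE) - log_ccdf
}

# Vectorized form: summed untruncated log density minus N times the correction.
lp_vec <- sum(dnorm(y, mean = mu, sd = sigma, log = TRUE)) - length(y) * log_ccdf

# TRUE if the two agree up to floating-point tolerance.
print(all.equal(lp_loop, lp_vec))
```

The two forms differ only in the order of the additions, so any disagreement beyond floating-point rounding would point to a mistake in the correction term.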


Cool, thanks for bringing it up. I opened an issue for the language (track it here if you want).

It’d be cool if we could turn this on in general.
