Truncation with vectorized sampling statements

It’s my impression that truncation with T[L,U] doesn’t work with a vectorized sampling statement. That’s a bit unfortunate, since vectorized distribution functions are much faster. Would the following work? It relies on an approach similar to the one described in the “Integrating out Censored Values” section of the User’s Guide.

data {
  int<lower=0> N;
  vector<lower=0>[N] y;
  int<lower=0,upper=1> lccdf;
}
parameters {
  real<lower=0> mu;
  real<lower=0> sigma;
}
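model {
  mu ~ normal(0, 5);
  sigma ~ gamma(2, 1);

  if (lccdf) {
    // Vectorized: every y[i] shares mu, sigma, and the lower bound 0,
    // so the truncation normalizer is one term repeated N times.
    y ~ normal(mu, sigma);
    target += -normal_lccdf(0 | mu, sigma) * N;
  } else {
    // Loop: built-in truncation, one statement per observation.
    for (i in 1:N) y[i] ~ normal(mu, sigma) T[0,];
  }
}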

I think the flag lccdf switches between two ways of specifying the same truncated normal model. One relies on a for loop and samples relatively slowly. The other, I think, implements the same truncation while still getting the speed benefit of a vectorized distribution function. The two seem to produce similar posterior draws, though I was surprised that the User’s Guide and Reference Manual only show the loop approach.
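A sketch of why the trick is valid, written out as a formula: a normal truncated below at 0 divides each density by P(Y > 0) = 1 - Φ(-μ/σ), where Φ is the standard normal CDF. Because μ, σ, and the bound are shared by all N observations, that normalizer is the same for every term, so on the log scale it factors out as N copies of one term:

```latex
\log \prod_{i=1}^{N} \frac{\mathrm{Normal}(y_i \mid \mu, \sigma)}{1 - \Phi\!\left(\frac{0 - \mu}{\sigma}\right)}
= \sum_{i=1}^{N} \log \mathrm{Normal}(y_i \mid \mu, \sigma)
\;-\; N \log\!\left(1 - \Phi\!\left(\frac{0 - \mu}{\sigma}\right)\right)
```

The sum is what the vectorized y ~ normal(mu, sigma) adds (up to dropped constants), and the last term is -normal_lccdf(0 | mu, sigma) * N. With observation-specific bounds or parameters, the normalizer would no longer be a single repeated term.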

Here’s an R script to run the model, using the cmdstanr interface.

library(cmdstanr)

m <- cmdstan_model("truncation.stan")

# Inverse-CDF sampler for a normal truncated to [lower, upper]
rtnorm <- function(n, mean = 0, sd = 1, lower=-Inf, upper=Inf){
  l <- pnorm(lower, mean = mean, sd = sd)
  u <- pnorm(upper, mean = mean, sd = sd)
  uni <- runif(n, l, u)
  qnorm(uni, mean=mean, sd=sd)
}

N <- 1e4
mu <- 0.5
sigma <- 5
y <- rtnorm(N, mean=mu, sd=sigma, lower=0)
d <- list(N = N, y = y, lccdf=1)

fit <- m$sample(
  data = d,
  seed = 123,
  chains = 3,
  parallel_chains = 3)

bayesplot::mcmc_pairs(fit$draws())


d2 <- list(N = N, y = y, lccdf=0)

fit2 <- m$sample(
  data = d2,
  seed = 123,
  chains = 3,
  parallel_chains = 3)

bayesplot::mcmc_pairs(fit2$draws())
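To sanity-check the simulated data before fitting, here is a base-R sketch that compares the sample mean of rtnorm() draws with the analytic mean of a normal truncated below at 0, mu + sigma * dnorm(alpha) / (1 - pnorm(alpha)) with alpha = (0 - mu) / sigma. The rtnorm() function is copied from the script above; the draw count and the 0.1 tolerance are arbitrary choices for illustration.

```r
# Inverse-CDF sampler for a truncated normal (same as in the script above)
rtnorm <- function(n, mean = 0, sd = 1, lower = -Inf, upper = Inf) {
  l <- pnorm(lower, mean = mean, sd = sd)  # CDF at the lower bound
  u <- pnorm(upper, mean = mean, sd = sd)  # CDF at the upper bound
  uni <- runif(n, l, u)                    # uniform draws strictly between l and u
  qnorm(uni, mean = mean, sd = sd)         # map back through the inverse CDF
}

set.seed(123)
mu <- 0.5
sigma <- 5
y_check <- rtnorm(1e5, mean = mu, sd = sigma, lower = 0)

# Analytic mean of Normal(mu, sigma) truncated below at 0
alpha <- (0 - mu) / sigma
analytic_mean <- mu + sigma * dnorm(alpha) / (1 - pnorm(alpha))

all_nonneg <- all(y_check >= 0)
mean_gap <- abs(mean(y_check) - analytic_mean)
print(c(sample = mean(y_check), analytic = analytic_mean))
```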

That looks right to me. Neat trick. I wonder why the truncation doesn’t work with vectors. Maybe this is something we should add.

Does it give the same results either way? I’m not sure how to check this other than evaluating the lpdfs and comparing the numbers.
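Outside Stan, the same comparison can be sketched in base R: the "loop" version sums each observation's log density minus its own truncation term, and the "vectorized" version sums the log densities once and subtracts N copies of the log complementary CDF. Here pnorm(..., lower.tail = FALSE, log.p = TRUE) stands in for normal_lccdf; the data and parameter values are arbitrary.

```r
set.seed(1)
N <- 10
mu <- 0.5
sigma <- 5
y <- abs(rnorm(N, mu, sigma))  # any nonnegative values work for this check

# log P(Y > 0) for Y ~ Normal(mu, sigma); analogue of normal_lccdf(0 | mu, sigma)
log_ccdf0 <- pnorm(0, mean = mu, sd = sigma, lower.tail = FALSE, log.p = TRUE)

# Loop version: one truncated density per observation, like y[i] ~ normal(mu, sigma) T[0,]
lp_loop <- 0
for (i in seq_len(N)) {
  lp_loop <- lp_loop + dnorm(y[i], mu, sigma, log = TRUE) - log_ccdf0
}

# Vectorized version: sum of log densities minus N copies of the normalizer
lp_vec <- sum(dnorm(y, mu, sigma, log = TRUE)) - N * log_ccdf0

print(c(loop = lp_loop, vectorized = lp_vec, difference = abs(lp_loop - lp_vec)))
```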

I kind of like making both of them target += statements:

target += normal_lpdf(y | mu, sigma) - normal_lccdf(0 | mu, sigma) * N;

But I guess it doesn’t matter that the ~ form drops some constants.
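One way to see why dropped constants don't matter: my understanding is that the ~ form omits the -0.5 * log(2 * pi) term for each observation (only terms that don't depend on parameters can be dropped), and that offset is the same for every (mu, sigma). A base-R sketch comparing a full and a constant-dropped truncated log likelihood at a few arbitrary parameter values:

```r
y <- c(0.3, 1.2, 2.5, 4.0)  # arbitrary nonnegative data
N <- length(y)

# Full version, like target += normal_lpdf(y | mu, sigma) - normal_lccdf(0 | mu, sigma) * N
lp_full <- function(mu, sigma) {
  sum(dnorm(y, mu, sigma, log = TRUE)) -
    N * pnorm(0, mu, sigma, lower.tail = FALSE, log.p = TRUE)
}

# Same thing without the parameter-free -0.5 * log(2 * pi) term per observation
lp_dropped <- function(mu, sigma) {
  sum(-log(sigma) - 0.5 * ((y - mu) / sigma)^2) -
    N * pnorm(0, mu, sigma, lower.tail = FALSE, log.p = TRUE)
}

params <- list(c(0.5, 5), c(1, 2), c(3, 0.7))
offsets <- sapply(params, function(p) lp_full(p[1], p[2]) - lp_dropped(p[1], p[2]))
print(offsets)  # each entry is -N * 0.5 * log(2 * pi), up to floating-point rounding
```

Since the offset does not depend on mu or sigma, it leaves the posterior unchanged.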

I wasn’t sure how to check either, but yes, looking at the output of the lpdfs seems like it should work. I think the following shows that the two methods give the same results.

data {
  int<lower=0> N;
  vector<lower=0>[N] y;
}
model {
  real lp1;
  real lp2;
  real lp1_0;
  real lp2_0;
  
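  // Method 1: vectorized statement plus the truncation term N times
  lp1_0 = target();
  y ~ normal(0, 1);
  target += -normal_lccdf(0 | 0, 1) * N;
  lp2_0 = target();
  lp1 = lp2_0 - lp1_0;
  
  // Method 2: built-in truncation inside a loop
  for (i in 1:N) y[i] ~ normal(0, 1) T[0,];
  lp2 = target() - lp2_0;
  
  print("lp2: ", lp2);
  print("lp1: ", lp1);
  // newer Stan versions use abs() for reals in place of the deprecated fabs()
  print("difference: ", abs(lp2 - lp1));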
}
generated quantities{
  real check = 0;
}

Run with

library(cmdstanr)

rtnorm <- function(n, mean = 0, sd = 1, lower=-Inf, upper=Inf){
  l <- pnorm(lower, mean = mean, sd = sd)
  u <- pnorm(upper, mean = mean, sd = sd)
  uni <- runif(n, l, u)
  qnorm(uni, mean=mean, sd=sd)
}

N <- 10
mu <- 0.5
sigma <- 5
y <- rtnorm(N, mean=mu, sd=sigma, lower=0)
d <- list(N = N, y = y)

m <- cmdstan_model("truncation_print.stan")

fit <- m$sample(
  data = d,
  seed = 123,
  chains = 1,
  iter_warmup = 2,
  iter_sampling = 0,
  fixed_param = TRUE)

fit$output()

gives

> fit$output()
[[1]]
 [1] ""                                                                                                
 [2] "method = sample (Default)"                                                                       
 [3] "  sample"                                                                                        
 [4] "    num_samples = 0"                                                                             
 [5] "    num_warmup = 2"                                                                              
 [6] "    save_warmup = 0 (Default)"                                                                   
 [7] "    thin = 1 (Default)"                                                                          
 [8] "    adapt"                                                                                       
 [9] "      engaged = 1 (Default)"                                                                     
[10] "      gamma = 0.050000000000000003 (Default)"                                                    
[11] "      delta = 0.80000000000000004 (Default)"                                                     
[12] "      kappa = 0.75 (Default)"                                                                    
[13] "      t0 = 10 (Default)"                                                                         
[14] "      init_buffer = 75 (Default)"                                                                
[15] "      term_buffer = 50 (Default)"                                                                
[16] "      window = 25 (Default)"                                                                     
[17] "    algorithm = fixed_param"                                                                     
[18] "id = 1"                                                                                          
[19] "data"                                                                                            
[20] "  file = C:/Users/psadi/AppData/Local/Temp/RtmpK8EpSN/standata-88c47e22dd8.json"                 
[21] "init = 2 (Default)"                                                                              
[22] "random"                                                                                          
[23] "  seed = 123"                                                                                    
[24] "output"                                                                                          
[25] "  file = C:/Users/psadi/AppData/Local/Temp/RtmpK8EpSN/truncation_print-202101081718-1-534bd1.csv"
[26] "  diagnostic_file =  (Default)"                                                                  
[27] "  refresh = 100 (Default)"                                                                       
[28] "  sig_figs = -1 (Default)"                                                                       
[29] ""                                                                                                
[30] "lp2: -145.471"                                                                                   
[31] "lp1: -145.471"                                                                                   
[32] "difference: 0"                                                                                   
[33] ""                                                                                                
[34] "lp2: 6.93147"                                                                                    
[35] "lp1: 6.93147"                                                                                    
[36] "difference: 0"                                                                                   
[37] ""                                                                                                
[38] ""                                                                                                
[39] " Elapsed Time: 0 seconds (Warm-up)"                                                              
[40] "               0 seconds (Sampling)"                                                             
[41] "               0 seconds (Total)"                                                                
[42] ""    

I think this means both methods have the same effect on the target.
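A side note on the second pair of printed values: 6.93147 matches 10 * log(2), which is exactly the truncation term -N * normal_lccdf(0 | 0, 1) for N = 10. That suggests (an inference from the numbers, not something the output states) that in the second evaluation the ~ statements dropped the normal density entirely, since with no parameters it is constant, so that pair compares only the normalizing terms, while the first pair includes the full densities. A quick base-R check of the arithmetic:

```r
N <- 10
# -N * log P(Y > 0) for Y ~ Normal(0, 1); analogue of -normal_lccdf(0 | 0, 1) * N
trunc_term <- -N * pnorm(0, mean = 0, sd = 1, lower.tail = FALSE, log.p = TRUE)
print(trunc_term, digits = 6)  # 6.93147, matching the second lp1/lp2 pair
```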

Cool, thanks for bringing it up. I made an issue for the language (track it here if you want).

It’d be cool if we could turn this on in general.