Autodetection of causes of divergent transitions


#1

Hello all,

I’ve seen someone claim that it is difficult or impossible to automate the search for causes of divergent transitions, because divergences can arise for many reasons. However, the general advice for dealing with models with divergences is to look at the pairs plot, where, as I understand it, panels that look like a standard bivariate normal are good, and funnels and other strange shapes are not.

…but if we know that bivariate normal pair plots are nice, can’t we use bivariate normality tests to get at least a hint of which parameters of a model might be problematic? See the example code with Neal’s funnel below.

library(MVN)
library(rstan)

y <- rnorm(1)
x <- rnorm(9, mean=0,sd=exp(y/2) )

stanmodelcode<-"
parameters {
  real y_raw;
  vector[9] x_raw;
}
transformed parameters {
  real y;
  vector[9] x;
  y = 3.0 * y_raw;
  x = exp(y/2) * x_raw;
}
model {
  y_raw ~ normal(0, 1); // implies y ~ normal(0, 3)
  x_raw ~ normal(0, 1); // implies x ~ normal(0, exp(y/2))
}
"

f <- stan(model_code=stanmodelcode,
          iter = 500)

smpls <- as.data.frame(extract(f))
pars <- dim(smpls)[2]-1  # skip the last column (lp__)
normtests <- matrix(NA, ncol=pars,nrow = pars,
                    dimnames = list(names(smpls)[1:pars],
                                    names(smpls)[1:pars])
)

# Mardia's skewness statistic for every pair of parameters.
# Note: mardiaTest() may not be available in newer MVN releases.
for(i in 1:(pars-1)){
  for(j in (i+1):pars){
    tmp <- try(mardiaTest(cbind(smpls[,i],smpls[,j]), qqplot = FALSE))
    if(!inherits(tmp,"try-error")){
      normtests[i,j] <- tmp@chi.skew
    }
  }
}
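
Since mardiaTest may not be available in every MVN release, the skewness statistic can also be computed directly in base R. The following is only a sketch under stated assumptions: `mardia_skew` is a hypothetical helper (not MVN's implementation), and the funnel draws are simulated with `rnorm` instead of coming from Stan. It uses the standard form of Mardia's test, where n times the sample skewness divided by 6 is compared to a chi-square with p(p+1)(p+2)/6 degrees of freedom.

```r
# Hypothetical helper (not part of MVN): Mardia's multivariate skewness test.
mardia_skew <- function(X) {
  X <- as.matrix(X)
  n <- nrow(X)
  p <- ncol(X)
  Xc <- scale(X, center = TRUE, scale = FALSE)  # centre each column
  S <- crossprod(Xc) / n                        # ML covariance (divisor n)
  D <- Xc %*% solve(S) %*% t(Xc)                # n x n cross Mahalanobis products
  b1 <- sum(D^3) / n^2                          # sample multivariate skewness
  stat <- n * b1 / 6
  df <- p * (p + 1) * (p + 2) / 6
  c(statistic = stat, df = df,
    p.value = pchisq(stat, df, lower.tail = FALSE))
}

# Simulated stand-ins for posterior draws (assumption: no Stan run involved).
set.seed(1)
n <- 1000
y <- rnorm(n, 0, 3)
x <- exp(y / 2) * rnorm(n)                      # funnel-shaped pair (y, x)
funnel_res <- mardia_skew(cbind(y, x))
normal_res <- mardia_skew(cbind(rnorm(n), rnorm(n)))
print(rbind(funnel = funnel_res, normal = normal_res))
```

The n x n matrix keeps this simple but costs memory quadratic in the number of draws, so thinning the draws first may be needed for long chains.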


#2

+1 it would be useful to have a program that could at least identify specific variables that I should examine when I have divergent transitions.


#3

One way to check this is to compare (and sort by) the mean and standard deviation of each parameter for divergent vs. non-divergent draws, e.g. to see whether the divergent draws are concentrated somewhere.
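
A minimal sketch of that comparison, under loud assumptions: `div_summary` is a hypothetical helper, the draws are simulated, and the `divergent` flag is made up (it simply marks the neck of the funnel). With a real rstan fit the flag would presumably come from the `divergent__` column of `get_sampler_params(f, inc_warmup = FALSE)`, matched to draws that are kept in chain order rather than permuted.

```r
# Hypothetical helper: per-parameter summaries for divergent vs. other draws,
# sorted by the standardized difference in means.
div_summary <- function(draws, divergent) {
  stopifnot(nrow(draws) == length(divergent))
  res <- t(sapply(draws, function(v) {
    c(mean_div = mean(v[divergent]),
      mean_ok  = mean(v[!divergent]),
      std_diff = (mean(v[divergent]) - mean(v[!divergent])) / sd(v),
      sd_ratio = sd(v[divergent]) / sd(v[!divergent]))
  }))
  res <- as.data.frame(res)
  res[order(-abs(res$std_diff)), ]
}

# Simulated stand-ins (assumption: not real sampler output).
set.seed(1)
n <- 1000
draws <- data.frame(y = rnorm(n, 0, 3))
draws$x1 <- exp(draws$y / 2) * rnorm(n)
draws$z  <- rnorm(n)
divergent <- draws$y < -3      # made-up flag: "divergences" in the funnel neck

out <- div_summary(draws, divergent)
print(out)
```

Parameters at the top of the table are the ones where divergent draws sit in a different place than the rest, which is the concentration this check is looking for.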


#4

I haven’t checked this out extensively but I agree that it should work to provide clues. Or at least, I agree enough that we put in a grant to fund exploration/implementation of something related that automatically suggests where to look. It really starts to make sense when you have a lot of parameters. I’m not sure how well normality tests would work in practice since they aren’t all that sensitive.


#5

For the record, mardiaTest from the MVN package is fairly slow. A colleague of mine suggested simply evaluating the bivariate normal likelihood of the draws for each pair of parameters, using the posterior means and posterior covariance. Higher values should then indicate a distribution that looks more like a bivariate normal.

This solution can use the “mvnfast” package, which is implemented in C++. This might scale to models with many parameters (that is, if the binorm-trick is good enough at giving hints…). See code below.

library(mvnfast)
library(rstan)

y <- rnorm(1)
x <- rnorm(9, mean=0,sd=exp(y/2) )

stanmodelcode<-"
parameters {
  real y_raw;
  vector[9] x_raw;
}
transformed parameters {
  real y;
  vector[9] x;
  y = 3.0 * y_raw;
  x = exp(y/2) * x_raw;
}
model {
  y_raw ~ normal(0, 1); // implies y ~ normal(0, 3)
  x_raw ~ normal(0, 1); // implies x ~ normal(0, exp(y/2))
}
"

f <- stan(model_code=stanmodelcode,
          iter = 500)

smpls <- as.data.frame(extract(f))
pars <- dim(smpls)[2]-1
normtests <- matrix(NA, ncol=pars,nrow = pars,
                    dimnames = list(names(smpls)[1:pars],
                                    names(smpls)[1:pars])
)

for(i in 1:(pars-1)){
  for(j in (i+1):pars){
    parvar <- cbind(smpls[,i],smpls[,j])
    tmp <- try(sum(dmvn(parvar,mu=colMeans(parvar), sigma=cov(parvar), log=T)))
    if(inherits(tmp,"try-error")==F){
      normtests[i,j]=tmp
    }
  }
}
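
One property of this statistic may be worth checking before relying on the ranking. When the log-density is evaluated at the sample mean and sample covariance of the same draws, the summed quadratic term always equals (n-1)p, so the total depends on the draws only through the determinant of their sample covariance. The base-R sketch below checks that identity numerically; it uses simulated draws and hypothetical helper names, and `mahalanobis` stands in for mvnfast's `dmvn`.

```r
# Hypothetical helpers, base R only.
# Summed p-variate normal log-density at the sample mean and covariance.
sum_loglik <- function(X) {
  X <- as.matrix(X)
  p <- ncol(X)
  S <- cov(X)
  d2 <- mahalanobis(X, colMeans(X), S)          # squared Mahalanobis distances
  sum(-0.5 * (p * log(2 * pi) + log(det(S)) + d2))
}

# Closed form: only n, p and det(cov(X)) enter.
closed_form <- function(X) {
  X <- as.matrix(X)
  n <- nrow(X)
  p <- ncol(X)
  -0.5 * (n * p * log(2 * pi) + n * log(det(cov(X))) + (n - 1) * p)
}

# Simulated stand-ins for posterior draws (assumption: no Stan run involved).
set.seed(2)
n <- 1000
y <- rnorm(n, 0, 3)
funnel <- cbind(y, exp(y / 2) * rnorm(n))
gauss  <- cbind(rnorm(n), rnorm(n))

print(c(funnel = sum_loglik(funnel), funnel_closed = closed_form(funnel)))
print(c(gauss  = sum_loglik(gauss),  gauss_closed  = closed_form(gauss)))
```

If the two numbers agree for real draws as well, the sorted list may reflect how spread out a pair is at least as much as how normal it looks, so standardizing the draws or inspecting per-draw values could be worth trying alongside it.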

#6

Out of curiosity I ran this on a 1,385 parameter model that otherwise checks out fine. It was plenty fast (a few minutes at most) although I could see a really large model blowing it out. There were a whole bunch of mostly false-positives driven by a few outlying points but the interesting thing is you can ID individual points that don’t match the MVN assumption. The problem with this approach is that Stan’s HMC will happily sample from plenty of non-MVN distributions. If somebody has a real-life big model with problems and wants to check it out we’d love to hear about it.
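
To illustrate what identifying individual points can look like, here is a rough base-R sketch; it is an assumption about one way to do it, not the code that was run on the 1,385-parameter model. `flag_outliers` is a hypothetical helper, the draws are simulated, and one outlier is planted by hand. It uses the squared Mahalanobis distance, which orders points the same way as the bivariate normal log-density, and compares it to a chi-square quantile.

```r
# Hypothetical helper: indices of draws that sit far out under a fitted
# bivariate normal, judged by squared Mahalanobis distance.
flag_outliers <- function(pair, level = 0.999) {
  pair <- as.matrix(pair)
  d2 <- mahalanobis(pair, colMeans(pair), cov(pair))
  which(d2 > qchisq(level, df = ncol(pair)))
}

# Simulated stand-ins for two well-behaved parameters (assumption).
set.seed(3)
pair <- cbind(rnorm(500), rnorm(500))
pair[17, ] <- c(8, -8)          # planted outlier

flagged <- flag_outliers(pair)
print(flagged)
```

With many parameter pairs, counting how often each draw is flagged might help separate a few stray draws from pairs that are non-normal throughout.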


#7

…are we really missing poorly written models!? :-) That sounds like a solvable issue.

Below is Piironen and Vehtari’s regularized horseshoe prior model. They have two parameterizations in their appendix - one good and one less good. Below is the bad one.
When I ran this script I got 592 divergent transitions. The list of pairs of variables, sorted by likelihood, gave the following:

1: Combinations of the intercept and some other parameter. Pretty sure this is a false positive.
2: Combinations of a parameter and a transformation of the same parameter (e.g. x and ln x). So false positive, but easy to rule out.
3: the “lambda”-parameters. These should be reparameterised, and they are in Aki’s paper.

So yes - this method does give false positives, but it also seems to flag problem parameters. See script below.

 library(mvnfast)
 library(rstan)
 rstan_options(auto_write = TRUE)
 options(mc.cores = parallel::detectCores())
 rm(list=ls())
 set.seed(123)
 library(shinystan)
 
 
 n <- 10
 d <- 9
 bvec <- rnorm(d)/rnorm(d)
 x <- matrix(rnorm(n*d), nrow=n, ncol=d)  # random predictors (a constant matrix would make the columns collinear)
 y <- as.vector(x %*% bvec + rnorm(n))
 
 p0 <-5
 scale_icept<-100
 scale_global<-p0/((d-p0)*sqrt(n))
 nu_global <-1
 nu_local <-1 
 slab_scale <-100
 slab_df <-1
 
 stanmodelcode<-"
 data {
   int<lower=0> n; // number of observations
   int<lower=0> d; // number of predictors
   vector[n] y; // outputs
   matrix[n,d] x; // inputs
   real<lower=0> scale_icept; // prior std for the intercept
   real<lower=0> scale_global; // scale for the half-t prior for tau
   real<lower=1> nu_global; // degrees of freedom for the half-t priors for tau
   real<lower=1> nu_local; // degrees of freedom for the half-t priors for lambdas
   real<lower=0> slab_scale; // slab scale for the regularized horseshoe
   real<lower=0> slab_df; // slab degrees of freedom for the regularized horseshoe
 }
 parameters {
   real logsigma;
   real beta0;
   vector[d] z;
   real<lower=0> tau; // global shrinkage parameter
   vector<lower=0>[d] lambda; // local shrinkage parameter
   real<lower=0> caux;
 }
 transformed parameters {
   real<lower=0> sigma; // noise std
   vector<lower=0>[d] lambda_tilde; // 'truncated' local shrinkage parameter
   real<lower=0> c; // slab scale
   vector[d] beta; // regression coefficients
   vector[n] f; // latent function values
   sigma = exp(logsigma);
   c = slab_scale * sqrt(caux);
   lambda_tilde = sqrt( c^2 * square(lambda) ./ (c^2 + tau^2 * square(lambda)) );
   beta = z .* lambda_tilde * tau;
   f = beta0 + x * beta;
 }
 model {
   // half-t priors for lambdas and tau, and inverse-gamma for c^2
   z ~ normal(0, 1);
   lambda ~ student_t(nu_local, 0, 1);
   tau ~ student_t(nu_global, 0, scale_global * sigma);
   caux ~ inv_gamma(0.5 * slab_df, 0.5 * slab_df);
   beta0 ~ normal(0, scale_icept);
   y ~ normal(f, sigma);
 }
 "
 
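 f   <- stan(model_code = stanmodelcode,
             # pass the data explicitly instead of relying on a workspace lookup
             data = list(n=n, d=d, y=y, x=x,
                         scale_icept=scale_icept, scale_global=scale_global,
                         nu_global=nu_global, nu_local=nu_local,
                         slab_scale=slab_scale, slab_df=slab_df),
             iter=2000,
             control=list(max_treedepth=13),
             chains=2
 )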
 
 smpls <- as.data.frame(extract(f))
 
 pars <- dim(smpls)[2]-1
 
 normtests <- as.data.frame(matrix(NA, ncol=3,nrow = pars*(pars-1)/2))
 colnames(normtests) <- c("par1", "par2", "loglik")
 cnt<-1
 for(i in 1:(pars-1)){
   for(j in (i+1):pars){
     parvar <- cbind(smpls[,i],smpls[,j])
     normtests[cnt,1] <- names(smpls)[i]
     normtests[cnt,2] <- names(smpls)[j]
     tmp <- try(sum(dmvn(parvar,mu=colMeans(parvar), sigma=cov(parvar), log=T)))
     if(inherits(tmp,"try-error")==F){
       normtests[cnt,3]=tmp
     }
     cnt<-cnt+1
   }
 }
 normtests <- normtests[order(normtests[,3]),]
 head(normtests, n=20)

#8

I really never write any but I hear other people do. :P


#9

I think that, regardless of the approach, the bulk of the work is here: figuring out how the false positives and false negatives appear. What you need is a relatively large set of models with known good/bad parameters, and constructing that is a serious effort. I agree that filtering out parameter self-comparisons and transformations is straightforward.
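
As a rough idea of what such a filter could look like (a sketch only, with hypothetical helper names and simulated draws): a parameter and a monotone transformation of it, such as logsigma and sigma = exp(logsigma), have a Spearman rank correlation of 1 in absolute value, so those pairs can be dropped before ranking. Transformations that are not monotone, or that involve other parameters, would not be caught this way.

```r
# Hypothetical helper: TRUE when b is (numerically) a monotone function of a.
is_monotone_pair <- function(a, b, tol = 1e-8) {
  abs(cor(a, b, method = "spearman")) > 1 - tol
}

# Simulated stand-ins for posterior draws (assumption: no Stan run involved).
set.seed(4)
draws <- data.frame(logsigma = rnorm(500))
draws$sigma <- exp(draws$logsigma)   # deterministic transform of logsigma
draws$beta0 <- rnorm(500)

# All parameter pairs, one row per pair.
pairs_idx <- t(combn(names(draws), 2))
keep <- !apply(pairs_idx, 1, function(nm) {
  is_monotone_pair(draws[[nm[1]]], draws[[nm[2]]])
})
kept_pairs <- pairs_idx[keep, , drop = FALSE]
print(kept_pairs)
```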


#10

I agree with that. However, the method above is an improvement over “looking at pairs plots” in that it is more systematic. But yes, thinking is still required.

Maybe if it is combined with other tricks, e.g. @ahartikainen’s suggestion of comparing the mean/sd of divergent and non-divergent draws, it will be better at pinpointing sources of errors.

I don’t have an ocean of models producing divergence-errors lying around (I hesitate to say “unfortunately”), but I’ll try this in future work and see if it is useful.